summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-01-20 22:44:42 -0500
committerCharles Harris <charlesr.harris@gmail.com>2015-01-20 22:44:42 -0500
commit960433e8f5f39587d10c60ed1d3f50591434a82b (patch)
tree4b57f21d0e1312adb349dd74a4558a3b3dba0e02 /numpy
parente73d4fcb2a0052572e6c1efaffa2b05f5931956e (diff)
parenta7fdf04f2a527055afe53dfaffaca09931b12a2d (diff)
downloadnumpy-960433e8f5f39587d10c60ed1d3f50591434a82b.tar.gz
Merge pull request #5468 from jaimefrio/swapaxes_view
ENH: Make swapaxes always return a view. Fixes #5260
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/fromnumeric.py6
-rw-r--r--numpy/core/src/multiarray/shape.c40
-rw-r--r--numpy/core/tests/test_multiarray.py32
3 files changed, 48 insertions, 30 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 2a527a4a4..aef09411a 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -464,8 +464,10 @@ def swapaxes(a, axis1, axis2):
Returns
-------
a_swapped : ndarray
- If `a` is an ndarray, then a view of `a` is returned; otherwise
- a new array is created.
+ For Numpy >= 1.10, if `a` is an ndarray, then a view of `a` is
+ returned; otherwise a new array is created. For earlier Numpy
+ versions a view of `a` is returned only if the order of the
+ axes is changed, otherwise the input array is returned.
Examples
--------
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c
index df1874594..f1e81ff6b 100644
--- a/numpy/core/src/multiarray/shape.c
+++ b/numpy/core/src/multiarray/shape.c
@@ -653,19 +653,8 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2)
{
PyArray_Dims new_axes;
npy_intp dims[NPY_MAXDIMS];
- int n, i, val;
- PyObject *ret;
-
- if (a1 == a2) {
- Py_INCREF(ap);
- return (PyObject *)ap;
- }
-
- n = PyArray_NDIM(ap);
- if (n <= 1) {
- Py_INCREF(ap);
- return (PyObject *)ap;
- }
+ int n = PyArray_NDIM(ap);
+ int i;
if (a1 < 0) {
a1 += n;
@@ -683,25 +672,20 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2)
"bad axis2 argument to swapaxes");
return NULL;
}
+
+ for (i = 0; i < n; ++i) {
+ dims[i] = i;
+ }
+ dims[a1] = a2;
+ dims[a2] = a1;
+
new_axes.ptr = dims;
new_axes.len = n;
- for (i = 0; i < n; i++) {
- if (i == a1) {
- val = a2;
- }
- else if (i == a2) {
- val = a1;
- }
- else {
- val = i;
- }
- new_axes.ptr[i] = val;
- }
- ret = PyArray_Transpose(ap, &new_axes);
- return ret;
+ return PyArray_Transpose(ap, &new_axes);
}
+
/*NUMPY_API
* Return Transpose.
*/
@@ -969,7 +953,7 @@ PyArray_Ravel(PyArrayObject *arr, NPY_ORDER order)
PyArray_CreateSortedStridePerm(PyArray_NDIM(arr),
PyArray_STRIDES(arr), strideperm);
-
+
for (i = ndim-1; i >= 0; --i) {
if (PyArray_DIM(arr, strideperm[i].perm) == 1) {
/* A size one dimension does not matter */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7ecb56cae..5138635d6 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1887,6 +1887,38 @@ class TestMethods(TestCase):
assert_equal(a.ravel('A'), [0, 2, 4, 6, 8, 10, 12, 14])
assert_equal(a.ravel('F'), [0, 8, 4, 12, 2, 10, 6, 14])
+ def test_swapaxes(self):
+ a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
+ idx = np.indices(a.shape)
+ assert_(a.flags['OWNDATA'])
+ b = a.copy()
+ # check exceptions
+ assert_raises(ValueError, a.swapaxes, -5, 0)
+ assert_raises(ValueError, a.swapaxes, 4, 0)
+ assert_raises(ValueError, a.swapaxes, 0, -5)
+ assert_raises(ValueError, a.swapaxes, 0, 4)
+
+ for i in range(-4, 4):
+ for j in range(-4, 4):
+ for k, src in enumerate((a, b)):
+ c = src.swapaxes(i, j)
+ # check shape
+ shape = list(src.shape)
+ shape[i] = src.shape[j]
+ shape[j] = src.shape[i]
+ assert_equal(c.shape, shape, str((i, j, k)))
+ # check array contents
+ i0, i1, i2, i3 = [dim-1 for dim in c.shape]
+ j0, j1, j2, j3 = [dim-1 for dim in src.shape]
+ assert_equal(src[idx[j0], idx[j1], idx[j2], idx[j3]],
+ c[idx[i0], idx[i1], idx[i2], idx[i3]],
+ str((i, j, k)))
+ # check a view is always returned, gh-5260
+ assert_(not c.flags['OWNDATA'], str((i, j, k)))
+ # check on non-contiguous input array
+ if k == 1:
+ b = c
+
def test_conjugate(self):
a = np.array([1-1j, 1+1j, 23+23.0j])
ac = a.conj()