diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/fromnumeric.py | 6 | ||||
-rw-r--r-- | numpy/core/src/multiarray/shape.c | 40 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 32 |
3 files changed, 48 insertions, 30 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 2a527a4a4..aef09411a 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -464,8 +464,10 @@ def swapaxes(a, axis1, axis2): Returns ------- a_swapped : ndarray - If `a` is an ndarray, then a view of `a` is returned; otherwise - a new array is created. + For Numpy >= 1.10, if `a` is an ndarray, then a view of `a` is + returned; otherwise a new array is created. For earlier Numpy + versions a view of `a` is returned only if the order of the + axes is changed, otherwise the input array is returned. Examples -------- diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c index df1874594..f1e81ff6b 100644 --- a/numpy/core/src/multiarray/shape.c +++ b/numpy/core/src/multiarray/shape.c @@ -653,19 +653,8 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2) { PyArray_Dims new_axes; npy_intp dims[NPY_MAXDIMS]; - int n, i, val; - PyObject *ret; - - if (a1 == a2) { - Py_INCREF(ap); - return (PyObject *)ap; - } - - n = PyArray_NDIM(ap); - if (n <= 1) { - Py_INCREF(ap); - return (PyObject *)ap; - } + int n = PyArray_NDIM(ap); + int i; if (a1 < 0) { a1 += n; @@ -683,25 +672,20 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2) "bad axis2 argument to swapaxes"); return NULL; } + + for (i = 0; i < n; ++i) { + dims[i] = i; + } + dims[a1] = a2; + dims[a2] = a1; + new_axes.ptr = dims; new_axes.len = n; - for (i = 0; i < n; i++) { - if (i == a1) { - val = a2; - } - else if (i == a2) { - val = a1; - } - else { - val = i; - } - new_axes.ptr[i] = val; - } - ret = PyArray_Transpose(ap, &new_axes); - return ret; + return PyArray_Transpose(ap, &new_axes); } + /*NUMPY_API * Return Transpose. */ @@ -969,7 +953,7 @@ PyArray_Ravel(PyArrayObject *arr, NPY_ORDER order) PyArray_CreateSortedStridePerm(PyArray_NDIM(arr), PyArray_STRIDES(arr), strideperm); - + for (i = ndim-1; i >= 0; --i) { if (PyArray_DIM(arr, strideperm[i].perm) == 1) { /* A size one dimension does not matter */ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7ecb56cae..5138635d6 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1887,6 +1887,38 @@ class TestMethods(TestCase): assert_equal(a.ravel('A'), [0, 2, 4, 6, 8, 10, 12, 14]) assert_equal(a.ravel('F'), [0, 8, 4, 12, 2, 10, 6, 14]) + def test_swapaxes(self): + a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy() + idx = np.indices(a.shape) + assert_(a.flags['OWNDATA']) + b = a.copy() + # check exceptions + assert_raises(ValueError, a.swapaxes, -5, 0) + assert_raises(ValueError, a.swapaxes, 4, 0) + assert_raises(ValueError, a.swapaxes, 0, -5) + assert_raises(ValueError, a.swapaxes, 0, 4) + + for i in range(-4, 4): + for j in range(-4, 4): + for k, src in enumerate((a, b)): + c = src.swapaxes(i, j) + # check shape + shape = list(src.shape) + shape[i] = src.shape[j] + shape[j] = src.shape[i] + assert_equal(c.shape, shape, str((i, j, k))) + # check array contents + i0, i1, i2, i3 = [dim-1 for dim in c.shape] + j0, j1, j2, j3 = [dim-1 for dim in src.shape] + assert_equal(src[idx[j0], idx[j1], idx[j2], idx[j3]], + c[idx[i0], idx[i1], idx[i2], idx[i3]], + str((i, j, k))) + # check a view is always returned, gh-5260 + assert_(not c.flags['OWNDATA'], str((i, j, k))) + # check on non-contiguous input array + if k == 1: + b = c + def test_conjugate(self): a = np.array([1-1j, 1+1j, 23+23.0j]) ac = a.conj() |