summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime <jaime.frio@gmail.com>2015-01-19 09:10:22 -0800
committerJaime <jaime.frio@gmail.com>2015-01-19 09:10:22 -0800
commit4ed1587a7de85b4fa01dff8ef6e0e901a25f149c (patch)
tree74083e9a2a04521596d8c42b016482512d516f28
parent8a81c08e777996b933fe4568ee5a6a0e01416faf (diff)
parentcf41fceb22df5e6f2f48108a4beb4325e5f8b7fa (diff)
downloadnumpy-4ed1587a7de85b4fa01dff8ef6e0e901a25f149c.tar.gz
Merge pull request #5464 from charris/rollaxis-always-return-view
Rollaxis always return view
-rw-r--r--doc/release/1.10.0-notes.rst6
-rw-r--r--numpy/core/numeric.py17
-rw-r--r--numpy/core/tests/test_numeric.py63
3 files changed, 80 insertions, 6 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst
index eba709f6b..26559ad32 100644
--- a/doc/release/1.10.0-notes.rst
+++ b/doc/release/1.10.0-notes.rst
@@ -50,6 +50,12 @@ the case of matrices. Matrices are special cased for backward
compatibility and still return 1-D arrays as before. If you need to
preserve the matrix subtype, use the methods instead of the functions.
+*rollaxis* always returns a view
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Previously, a view was returned except when no change was made in the order
+of the axes, in which case the input array was returned. A view is now
+returned in all cases.
+
New Features
============
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 430f7a715..a464562c4 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1392,6 +1392,7 @@ def roll(a, shift, axis=None):
res = res.reshape(a.shape)
return res
+
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
@@ -1410,7 +1411,9 @@ def rollaxis(a, axis, start=0):
Returns
-------
res : ndarray
- Output array.
+ For Numpy >= 1.10 a view of `a` is always returned. For earlier
+ Numpy versions a view of `a` is returned only if the order of the
+ axes is changed, otherwise the input array is returned.
See Also
--------
@@ -1436,17 +1439,19 @@ def rollaxis(a, axis, start=0):
msg = 'rollaxis: %s (%d) must be >=0 and < %d'
if not (0 <= axis < n):
raise ValueError(msg % ('axis', axis, n))
- if not (0 <= start < n+1):
- raise ValueError(msg % ('start', start, n+1))
- if (axis < start): # it's been removed
+ if not (0 <= start < n + 1):
+ raise ValueError(msg % ('start', start, n + 1))
+ if (axis < start):
+ # it's been removed
start -= 1
- if axis==start:
- return a
+ if axis == start:
+ return a[...]
axes = list(range(0, n))
axes.remove(axis)
axes.insert(start, axis)
return a.transpose(axes)
+
# fix hack in scipy which imports this function
def _move_axis_to_0(a, axis):
return rollaxis(a, axis, 0)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b151e24f3..7948b1355 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2065,6 +2065,69 @@ class TestRoll(TestCase):
x = np.array([])
assert_equal(np.roll(x, 1), np.array([]))
+
+class TestRollaxis(TestCase):
+
+ # expected shape indexed by (axis, start) for array of
+ # shape (1, 2, 3, 4)
+ tgtshape = {(0, 0): (1, 2, 3, 4), (0, 1): (1, 2, 3, 4),
+ (0, 2): (2, 1, 3, 4), (0, 3): (2, 3, 1, 4),
+ (0, 4): (2, 3, 4, 1),
+ (1, 0): (2, 1, 3, 4), (1, 1): (1, 2, 3, 4),
+ (1, 2): (1, 2, 3, 4), (1, 3): (1, 3, 2, 4),
+ (1, 4): (1, 3, 4, 2),
+ (2, 0): (3, 1, 2, 4), (2, 1): (1, 3, 2, 4),
+ (2, 2): (1, 2, 3, 4), (2, 3): (1, 2, 3, 4),
+ (2, 4): (1, 2, 4, 3),
+ (3, 0): (4, 1, 2, 3), (3, 1): (1, 4, 2, 3),
+ (3, 2): (1, 2, 4, 3), (3, 3): (1, 2, 3, 4),
+ (3, 4): (1, 2, 3, 4)}
+
+ def test_exceptions(self):
+ a = arange(1*2*3*4).reshape(1, 2, 3, 4)
+ assert_raises(ValueError, rollaxis, a, -5, 0)
+ assert_raises(ValueError, rollaxis, a, 0, -5)
+ assert_raises(ValueError, rollaxis, a, 4, 0)
+ assert_raises(ValueError, rollaxis, a, 0, 5)
+
+ def test_results(self):
+ a = arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
+ aind = np.indices(a.shape)
+ assert_(a.flags['OWNDATA'])
+ for (i, j) in self.tgtshape:
+ # positive axis, positive start
+ res = rollaxis(a, axis=i, start=j)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(i, j)], str((i,j)))
+ assert_(not res.flags['OWNDATA'])
+
+ # negative axis, positive start
+ ip = i + 1
+ res = rollaxis(a, axis=-ip, start=j)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(4 - ip, j)])
+ assert_(not res.flags['OWNDATA'])
+
+ # positive axis, negative start
+ jp = j + 1 if j < 4 else j
+ res = rollaxis(a, axis=i, start=-jp)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(i, 4 - jp)])
+ assert_(not res.flags['OWNDATA'])
+
+ # negative axis, negative start
+ ip = i + 1
+ jp = j + 1 if j < 4 else j
+ res = rollaxis(a, axis=-ip, start=-jp)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(4 - ip, 4 - jp)])
+ assert_(not res.flags['OWNDATA'])
+
+
class TestCross(TestCase):
def test_2x2(self):
u = [1, 2]