diff options
author | Jaime <jaime.frio@gmail.com> | 2015-01-19 09:10:22 -0800 |
---|---|---|
committer | Jaime <jaime.frio@gmail.com> | 2015-01-19 09:10:22 -0800 |
commit | 4ed1587a7de85b4fa01dff8ef6e0e901a25f149c (patch) | |
tree | 74083e9a2a04521596d8c42b016482512d516f28 /numpy | |
parent | 8a81c08e777996b933fe4568ee5a6a0e01416faf (diff) | |
parent | cf41fceb22df5e6f2f48108a4beb4325e5f8b7fa (diff) | |
download | numpy-4ed1587a7de85b4fa01dff8ef6e0e901a25f149c.tar.gz |
Merge pull request #5464 from charris/rollaxis-always-return-view
Rollaxis always return view
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 63 |
2 files changed, 74 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 430f7a715..a464562c4 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1392,6 +1392,7 @@ def roll(a, shift, axis=None): res = res.reshape(a.shape) return res + def rollaxis(a, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. @@ -1410,7 +1411,9 @@ def rollaxis(a, axis, start=0): Returns ------- res : ndarray - Output array. + For Numpy >= 1.10 a view of `a` is always returned. For earlier + Numpy versions a view of `a` is returned only if the order of the + axes is changed, otherwise the input array is returned. See Also -------- @@ -1436,17 +1439,19 @@ def rollaxis(a, axis, start=0): msg = 'rollaxis: %s (%d) must be >=0 and < %d' if not (0 <= axis < n): raise ValueError(msg % ('axis', axis, n)) - if not (0 <= start < n+1): - raise ValueError(msg % ('start', start, n+1)) - if (axis < start): # it's been removed + if not (0 <= start < n + 1): + raise ValueError(msg % ('start', start, n + 1)) + if (axis < start): + # it's been removed start -= 1 - if axis==start: - return a + if axis == start: + return a[...] axes = list(range(0, n)) axes.remove(axis) axes.insert(start, axis) return a.transpose(axes) + # fix hack in scipy which imports this function def _move_axis_to_0(a, axis): return rollaxis(a, axis, 0) diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index b151e24f3..7948b1355 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2065,6 +2065,69 @@ class TestRoll(TestCase): x = np.array([]) assert_equal(np.roll(x, 1), np.array([])) + +class TestRollaxis(TestCase): + + # expected shape indexed by (axis, start) for array of + # shape (1, 2, 3, 4) + tgtshape = {(0, 0): (1, 2, 3, 4), (0, 1): (1, 2, 3, 4), + (0, 2): (2, 1, 3, 4), (0, 3): (2, 3, 1, 4), + (0, 4): (2, 3, 4, 1), + (1, 0): (2, 1, 3, 4), (1, 1): (1, 2, 3, 4), + (1, 2): (1, 2, 3, 4), (1, 3): (1, 3, 2, 4), + (1, 4): (1, 3, 4, 2), + (2, 0): (3, 1, 2, 4), (2, 1): (1, 3, 2, 4), + (2, 2): (1, 2, 3, 4), (2, 3): (1, 2, 3, 4), + (2, 4): (1, 2, 4, 3), + (3, 0): (4, 1, 2, 3), (3, 1): (1, 4, 2, 3), + (3, 2): (1, 2, 4, 3), (3, 3): (1, 2, 3, 4), + (3, 4): (1, 2, 3, 4)} + + def test_exceptions(self): + a = arange(1*2*3*4).reshape(1, 2, 3, 4) + assert_raises(ValueError, rollaxis, a, -5, 0) + assert_raises(ValueError, rollaxis, a, 0, -5) + assert_raises(ValueError, rollaxis, a, 4, 0) + assert_raises(ValueError, rollaxis, a, 0, 5) + + def test_results(self): + a = arange(1*2*3*4).reshape(1, 2, 3, 4).copy() + aind = np.indices(a.shape) + assert_(a.flags['OWNDATA']) + for (i, j) in self.tgtshape: + # positive axis, positive start + res = rollaxis(a, axis=i, start=j) + i0, i1, i2, i3 = aind[np.array(res.shape) - 1] + assert_(np.all(res[i0, i1, i2, i3] == a)) + assert_(res.shape == self.tgtshape[(i, j)], str((i,j))) + assert_(not res.flags['OWNDATA']) + + # negative axis, positive start + ip = i + 1 + res = rollaxis(a, axis=-ip, start=j) + i0, i1, i2, i3 = aind[np.array(res.shape) - 1] + assert_(np.all(res[i0, i1, i2, i3] == a)) + assert_(res.shape == self.tgtshape[(4 - ip, j)]) + assert_(not res.flags['OWNDATA']) + + # positive axis, negative start + jp = j + 1 if j < 4 else j + res = rollaxis(a, axis=i, start=-jp) + i0, i1, i2, i3 = aind[np.array(res.shape) - 1] + assert_(np.all(res[i0, i1, i2, i3] == a)) + assert_(res.shape == self.tgtshape[(i, 4 - jp)]) + assert_(not res.flags['OWNDATA']) + + # negative axis, negative start + ip = i + 1 + jp = j + 1 if j < 4 else j + res = rollaxis(a, axis=-ip, start=-jp) + i0, i1, i2, i3 = aind[np.array(res.shape) - 1] + assert_(np.all(res[i0, i1, i2, i3] == a)) + assert_(res.shape == self.tgtshape[(4 - ip, 4 - jp)]) + assert_(not res.flags['OWNDATA']) + + class TestCross(TestCase): def test_2x2(self): u = [1, 2] |