summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-01-18 19:23:47 -0700
committerCharles Harris <charlesr.harris@gmail.com>2015-01-19 09:14:48 -0700
commitcf41fceb22df5e6f2f48108a4beb4325e5f8b7fa (patch)
tree74083e9a2a04521596d8c42b016482512d516f28
parent38b1a7c3038223ee72138413621ef3e00c9224f5 (diff)
downloadnumpy-cf41fceb22df5e6f2f48108a4beb4325e5f8b7fa.tar.gz
TST: Tests for numeric.rollaxis.
There were no tests previous to this.
-rw-r--r--numpy/core/tests/test_numeric.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b151e24f3..7948b1355 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2065,6 +2065,69 @@ class TestRoll(TestCase):
x = np.array([])
assert_equal(np.roll(x, 1), np.array([]))
+
+class TestRollaxis(TestCase):
+
+ # expected shape indexed by (axis, start) for array of
+ # shape (1, 2, 3, 4)
+ tgtshape = {(0, 0): (1, 2, 3, 4), (0, 1): (1, 2, 3, 4),
+ (0, 2): (2, 1, 3, 4), (0, 3): (2, 3, 1, 4),
+ (0, 4): (2, 3, 4, 1),
+ (1, 0): (2, 1, 3, 4), (1, 1): (1, 2, 3, 4),
+ (1, 2): (1, 2, 3, 4), (1, 3): (1, 3, 2, 4),
+ (1, 4): (1, 3, 4, 2),
+ (2, 0): (3, 1, 2, 4), (2, 1): (1, 3, 2, 4),
+ (2, 2): (1, 2, 3, 4), (2, 3): (1, 2, 3, 4),
+ (2, 4): (1, 2, 4, 3),
+ (3, 0): (4, 1, 2, 3), (3, 1): (1, 4, 2, 3),
+ (3, 2): (1, 2, 4, 3), (3, 3): (1, 2, 3, 4),
+ (3, 4): (1, 2, 3, 4)}
+
+ def test_exceptions(self):
+ a = arange(1*2*3*4).reshape(1, 2, 3, 4)
+ assert_raises(ValueError, rollaxis, a, -5, 0)
+ assert_raises(ValueError, rollaxis, a, 0, -5)
+ assert_raises(ValueError, rollaxis, a, 4, 0)
+ assert_raises(ValueError, rollaxis, a, 0, 5)
+
+ def test_results(self):
+ a = arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
+ aind = np.indices(a.shape)
+ assert_(a.flags['OWNDATA'])
+ for (i, j) in self.tgtshape:
+ # positive axis, positive start
+ res = rollaxis(a, axis=i, start=j)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(i, j)], str((i,j)))
+ assert_(not res.flags['OWNDATA'])
+
+ # negative axis, positive start
+ ip = i + 1
+ res = rollaxis(a, axis=-ip, start=j)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(4 - ip, j)])
+ assert_(not res.flags['OWNDATA'])
+
+ # positive axis, negative start
+ jp = j + 1 if j < 4 else j
+ res = rollaxis(a, axis=i, start=-jp)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(i, 4 - jp)])
+ assert_(not res.flags['OWNDATA'])
+
+ # negative axis, negative start
+ ip = i + 1
+ jp = j + 1 if j < 4 else j
+ res = rollaxis(a, axis=-ip, start=-jp)
+ i0, i1, i2, i3 = aind[np.array(res.shape) - 1]
+ assert_(np.all(res[i0, i1, i2, i3] == a))
+ assert_(res.shape == self.tgtshape[(4 - ip, 4 - jp)])
+ assert_(not res.flags['OWNDATA'])
+
+
class TestCross(TestCase):
def test_2x2(self):
u = [1, 2]