diff options
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 01ea028bb..be1604a75 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -289,14 +289,26 @@ class TestExpandDims(object): assert_(b.shape[axis] == 1) assert_(np.squeeze(b).shape == s) - def test_deprecations(self): - # 2017-05-17, 1.13.0 + def test_axis_tuple(self): + a = np.empty((3, 3, 3)) + assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3) + assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1) + assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1) + assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) + + def test_axis_out_of_range(self): s = (2, 3, 4, 5) a = np.empty(s) - with warnings.catch_warnings(): - warnings.simplefilter("always") - assert_warns(DeprecationWarning, expand_dims, a, -6) - assert_warns(DeprecationWarning, expand_dims, a, 5) + assert_raises(np.AxisError, expand_dims, a, -6) + assert_raises(np.AxisError, expand_dims, a, 5) + + a = np.empty((3, 3, 3)) + assert_raises(np.AxisError, expand_dims, a, (0, -6)) + assert_raises(np.AxisError, expand_dims, a, (0, 5)) + + def test_repeated_axis(self): + a = np.empty((3, 3, 3)) + assert_raises(ValueError, expand_dims, a, axis=(1, 1)) def test_subclasses(self): a = np.arange(10).reshape((2, 5)) |