diff options
author | Larry Bradley <larry.bradley@gmail.com> | 2019-12-02 17:06:42 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2019-12-02 16:06:42 -0600 |
commit | 03d489735e863e27f3e6ce39b8a85eca440c0231 (patch) | |
tree | 89bda3ea557fa7db40c0b5f9c0543e2ab93151e5 /numpy/lib/tests/test_shape_base.py | |
parent | 5992098524c9f36288093ef3298d44343735842e (diff) | |
download | numpy-03d489735e863e27f3e6ce39b8a85eca440c0231.tar.gz |
ENH,DEP: Allow multiple axes in expand_dims (#14051)
This PR allows the axis keyword in expand_dims to be a tuple of ints. Previously, axis could only be an int.
This issue was previously discussed in gh-12290 and the changes are based on gh-12290 (comment).
This PR also removes the deprecation added in v1.13 (2017-05-17), where previously axis could be outside of the range (-a.ndim - 1) <= axis <= a.ndim. Such an axis value will now raise an AxisError. Please let me know if it's too soon to remove this deprecation (I could not find any dev docs stating the length of the numpy deprecation cycle).
Closes gh-12290.
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 01ea028bb..be1604a75 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -289,14 +289,26 @@ class TestExpandDims(object): assert_(b.shape[axis] == 1) assert_(np.squeeze(b).shape == s) - def test_deprecations(self): - # 2017-05-17, 1.13.0 + def test_axis_tuple(self): + a = np.empty((3, 3, 3)) + assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3) + assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1) + assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1) + assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) + + def test_axis_out_of_range(self): s = (2, 3, 4, 5) a = np.empty(s) - with warnings.catch_warnings(): - warnings.simplefilter("always") - assert_warns(DeprecationWarning, expand_dims, a, -6) - assert_warns(DeprecationWarning, expand_dims, a, 5) + assert_raises(np.AxisError, expand_dims, a, -6) + assert_raises(np.AxisError, expand_dims, a, 5) + + a = np.empty((3, 3, 3)) + assert_raises(np.AxisError, expand_dims, a, (0, -6)) + assert_raises(np.AxisError, expand_dims, a, (0, 5)) + + def test_repeated_axis(self): + a = np.empty((3, 3, 3)) + assert_raises(ValueError, expand_dims, a, axis=(1, 1)) def test_subclasses(self): a = np.arange(10).reshape((2, 5)) |