diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-05-17 13:02:01 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2017-05-17 19:03:15 -0600 |
commit | beac50cf98f450539dcdeee0273cfe5175d45d26 (patch) | |
tree | 52ebc11c1a4130592bc227eb385fb513073a7c4a /numpy/lib/tests | |
parent | b9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53 (diff) | |
download | numpy-beac50cf98f450539dcdeee0273cfe5175d45d26.tar.gz |
DEP: Deprecate incorrect behavior of expand_dims.
Expand_dims works as documented when the index of the inserted NewAxis
in the resulting array satisfies -a.ndim - 1 <= index <= a.ndim.
However, when index > a.ndim index is replaced by a.ndim and, when
index < -a.ndim - 1, it is replaced by index + a.ndim + 1, which may be
negative and results in incorrect placement. The latter two cases are
now deprecated.
Closes #9100.
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 4d06001f4..14406fe21 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -1,9 +1,11 @@ from __future__ import division, absolute_import, print_function import numpy as np +import warnings + from numpy.lib.shape_base import ( apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit, - vsplit, dstack, column_stack, kron, tile + vsplit, dstack, column_stack, kron, tile, expand_dims, ) from numpy.testing import ( run_module_suite, TestCase, assert_, assert_equal, assert_array_equal, @@ -182,6 +184,25 @@ class TestApplyOverAxes(TestCase): assert_array_equal(aoa_a, np.array([[[60], [92], [124]]])) +class TestExpandDims(TestCase): + def test_functionality(self): + s = (2, 3, 4, 5) + a = np.empty(s) + for axis in range(-5, 4): + b = expand_dims(a, axis) + assert_(b.shape[axis] == 1) + assert_(np.squeeze(b).shape == s) + + def test_deprecations(self): + # 2017-05-17, 1.13.0 + s = (2, 3, 4, 5) + a = np.empty(s) + with warnings.catch_warnings(): + warnings.simplefilter("always") + assert_warns(DeprecationWarning, expand_dims, a, -6) + assert_warns(DeprecationWarning, expand_dims, a, 5) + + class TestArraySplit(TestCase): def test_integer_0_split(self): a = np.arange(10) |