summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py23
-rw-r--r--numpy/lib/tests/test_shape_base.py23
2 files changed, 41 insertions, 5 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 830943e72..ea77f40e0 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -240,14 +240,20 @@ def expand_dims(a, axis):
"""
Expand the shape of an array.
- Insert a new axis, corresponding to a given position in the array shape.
+ Insert a new axis that will appear at the `axis` position in the expanded
+ array shape.
+
+ .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
+ ``axis > a.ndim`` raised errors or put the new axis where documented.
+ Those axis values are now deprecated and will raise an AxisError in the
+ future.
Parameters
----------
a : array_like
Input array.
axis : int
- Position (amongst axes) where new axis is to be inserted.
+ Position in the expanded axes where the new axis is placed.
Returns
-------
@@ -291,7 +297,16 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
- axis = normalize_axis_index(axis, a.ndim + 1)
+ if axis > a.ndim or axis < -a.ndim - 1:
+ # 2017-05-17, 1.13.0
+ warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
+ "deprecated and will raise an AxisError in the future.",
+ DeprecationWarning, stacklevel=2)
+ # When the deprecation period expires, delete this if block,
+ if axis < 0:
+ axis = axis + a.ndim + 1
+ # and uncomment the following line.
+ # axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
row_stack = vstack
@@ -317,7 +332,7 @@ def column_stack(tup):
See Also
--------
- hstack, vstack, concatenate
+ stack, hstack, vstack, concatenate
Examples
--------
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 4d06001f4..14406fe21 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
import numpy as np
+import warnings
+
from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
- vsplit, dstack, column_stack, kron, tile
+ vsplit, dstack, column_stack, kron, tile, expand_dims,
)
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -182,6 +184,25 @@ class TestApplyOverAxes(TestCase):
assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
+class TestExpandDims(TestCase):
+ def test_functionality(self):
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ for axis in range(-5, 4):
+ b = expand_dims(a, axis)
+ assert_(b.shape[axis] == 1)
+ assert_(np.squeeze(b).shape == s)
+
+ def test_deprecations(self):
+ # 2017-05-17, 1.13.0
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ with warnings.catch_warnings():
+ warnings.simplefilter("always")
+ assert_warns(DeprecationWarning, expand_dims, a, -6)
+ assert_warns(DeprecationWarning, expand_dims, a, 5)
+
+
class TestArraySplit(TestCase):
def test_integer_0_split(self):
a = np.arange(10)