summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 01d13514a..ea77f40e0 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -240,14 +240,20 @@ def expand_dims(a, axis):
"""
Expand the shape of an array.
- Insert a new axis, corresponding to a given position in the array shape.
+ Insert a new axis that will appear at the `axis` position in the expanded
+ array shape.
+
+ .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
+ ``axis > a.ndim`` raised errors or put the new axis where documented.
+ Those axis values are now deprecated and will raise an AxisError in the
+ future.
Parameters
----------
a : array_like
Input array.
axis : int
- Position (amongst axes) where new axis is to be inserted.
+ Position in the expanded axes where the new axis is placed.
Returns
-------
@@ -291,7 +297,16 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
- axis = normalize_axis_index(axis, a.ndim + 1)
+ if axis > a.ndim or axis < -a.ndim - 1:
+ # 2017-05-17, 1.13.0
+ warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
+ "deprecated and will raise an AxisError in the future.",
+ DeprecationWarning, stacklevel=2)
+ # When the deprecation period expires, delete this if block,
+ if axis < 0:
+ axis = axis + a.ndim + 1
+ # and uncomment the following line.
+ # axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
row_stack = vstack