summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 81d8d9d17..588047952 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -7,6 +7,7 @@ from numpy.core.numeric import (
asarray, zeros, outer, concatenate, isscalar, array, asanyarray
)
from numpy.core.fromnumeric import product, reshape, transpose
+from numpy.core.multiarray import normalize_axis_index
from numpy.core import vstack, atleast_3d
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -96,10 +97,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
# handle negative axes
arr = asanyarray(arr)
nd = arr.ndim
- if not (-nd <= axis < nd):
- raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd))
- if axis < 0:
- axis += nd
+ axis = normalize_axis_index(axis, nd)
# arr, with the iteration axis at the end
in_dims = list(range(nd))
@@ -259,6 +257,8 @@ def expand_dims(a, axis):
See Also
--------
+ squeeze : The inverse operation, removing singleton dimensions
+ reshape : Insert, remove, and combine dimensions, and resize existing ones
doc.indexing, atleast_1d, atleast_2d, atleast_3d
Examples
@@ -291,8 +291,7 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
- if axis < 0:
- axis = axis + len(shape) + 1
+ axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
row_stack = vstack
@@ -372,7 +371,8 @@ def dstack(tup):
Notes
-----
- Equivalent to ``np.concatenate(tup, axis=2)``.
+ Equivalent to ``np.concatenate(tup, axis=2)`` if `tup` contains arrays that
+ are at least 3-dimensional.
Examples
--------
@@ -395,7 +395,7 @@ def dstack(tup):
def _replace_zero_by_x_arrays(sub_arys):
for i in range(len(sub_arys)):
- if len(_nx.shape(sub_arys[i])) == 0:
+ if _nx.ndim(sub_arys[i]) == 0:
sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype)
elif _nx.sometrue(_nx.equal(_nx.shape(sub_arys[i]), 0)):
sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype)
@@ -582,9 +582,9 @@ def hsplit(ary, indices_or_sections):
[[ 6., 7.]]])]
"""
- if len(_nx.shape(ary)) == 0:
+ if _nx.ndim(ary) == 0:
raise ValueError('hsplit only works on arrays of 1 or more dimensions')
- if len(ary.shape) > 1:
+ if ary.ndim > 1:
return split(ary, indices_or_sections, 1)
else:
return split(ary, indices_or_sections, 0)
@@ -636,7 +636,7 @@ def vsplit(ary, indices_or_sections):
[ 6., 7.]]])]
"""
- if len(_nx.shape(ary)) < 2:
+ if _nx.ndim(ary) < 2:
raise ValueError('vsplit only works on arrays of 2 or more dimensions')
return split(ary, indices_or_sections, 0)
@@ -681,7 +681,7 @@ def dsplit(ary, indices_or_sections):
array([], dtype=float64)]
"""
- if len(_nx.shape(ary)) < 3:
+ if _nx.ndim(ary) < 3:
raise ValueError('dsplit only works on arrays of 3 or more dimensions')
return split(ary, indices_or_sections, 2)