diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 81d8d9d17..588047952 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -7,6 +7,7 @@ from numpy.core.numeric import ( asarray, zeros, outer, concatenate, isscalar, array, asanyarray ) from numpy.core.fromnumeric import product, reshape, transpose +from numpy.core.multiarray import normalize_axis_index from numpy.core import vstack, atleast_3d from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -96,10 +97,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): # handle negative axes arr = asanyarray(arr) nd = arr.ndim - if not (-nd <= axis < nd): - raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd)) - if axis < 0: - axis += nd + axis = normalize_axis_index(axis, nd) # arr, with the iteration axis at the end in_dims = list(range(nd)) @@ -259,6 +257,8 @@ def expand_dims(a, axis): See Also -------- + squeeze : The inverse operation, removing singleton dimensions + reshape : Insert, remove, and combine dimensions, and resize existing ones doc.indexing, atleast_1d, atleast_2d, atleast_3d Examples @@ -291,8 +291,7 @@ def expand_dims(a, axis): """ a = asarray(a) shape = a.shape - if axis < 0: - axis = axis + len(shape) + 1 + axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) row_stack = vstack @@ -372,7 +371,8 @@ def dstack(tup): Notes ----- - Equivalent to ``np.concatenate(tup, axis=2)``. + Equivalent to ``np.concatenate(tup, axis=2)`` if `tup` contains arrays that + are at least 3-dimensional. Examples -------- @@ -395,7 +395,7 @@ def dstack(tup): def _replace_zero_by_x_arrays(sub_arys): for i in range(len(sub_arys)): - if len(_nx.shape(sub_arys[i])) == 0: + if _nx.ndim(sub_arys[i]) == 0: sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype) elif _nx.sometrue(_nx.equal(_nx.shape(sub_arys[i]), 0)): sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype) @@ -582,9 +582,9 @@ def hsplit(ary, indices_or_sections): [[ 6., 7.]]])] """ - if len(_nx.shape(ary)) == 0: + if _nx.ndim(ary) == 0: raise ValueError('hsplit only works on arrays of 1 or more dimensions') - if len(ary.shape) > 1: + if ary.ndim > 1: return split(ary, indices_or_sections, 1) else: return split(ary, indices_or_sections, 0) @@ -636,7 +636,7 @@ def vsplit(ary, indices_or_sections): [ 6., 7.]]])] """ - if len(_nx.shape(ary)) < 2: + if _nx.ndim(ary) < 2: raise ValueError('vsplit only works on arrays of 2 or more dimensions') return split(ary, indices_or_sections, 0) @@ -681,7 +681,7 @@ def dsplit(ary, indices_or_sections): array([], dtype=float64)] """ - if len(_nx.shape(ary)) < 3: + if _nx.ndim(ary) < 3: raise ValueError('dsplit only works on arrays of 3 or more dimensions') return split(ary, indices_or_sections, 2) |