diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 6 | ||||
-rw-r--r-- | numpy/ma/extras.py | 16 |
2 files changed, 7 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 9386ff3be..6bc1bc623 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -43,6 +43,7 @@ from numpy.compat import ( ) from numpy import expand_dims as n_expand_dims from numpy.core.multiarray import normalize_axis_index +from numpy.core.numeric import _validate_axis if sys.version_info[0] >= 3: @@ -4369,10 +4370,7 @@ class MaskedArray(ndarray): return np.array(self.size, dtype=np.intp, ndmin=self.ndim) return self.size - axes = axis if isinstance(axis, tuple) else (axis,) - axes = tuple(normalize_axis_index(a, self.ndim) for a in axes) - if len(axes) != len(set(axes)): - raise ValueError("duplicate value in 'axis'") + axes = _validate_axis(axis, self.ndim) items = 1 for ax in axes: items *= self.shape[ax] diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 697565251..8b37df902 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -37,6 +37,7 @@ import numpy as np from numpy import ndarray, array as nxarray import numpy.core.umath as umath from numpy.core.multiarray import normalize_axis_index +from numpy.core.numeric import _validate_axis from numpy.lib.function_base import _ureduce from numpy.lib.index_tricks import AxisConcatenator @@ -816,18 +817,11 @@ def compress_nd(x, axis=None): x = asarray(x) m = getmask(x) # Set axis to tuple of ints - if isinstance(axis, int): - axis = (axis,) - elif axis is None: + if axis is None: axis = tuple(range(x.ndim)) - elif not isinstance(axis, tuple): - raise ValueError('Invalid type for axis argument') - # Check axis input - axis = [ax + x.ndim if ax < 0 else ax for ax in axis] - if not all(0 <= ax < x.ndim for ax in axis): - raise ValueError("'axis' entry is out of bounds") - if len(axis) != len(set(axis)): - raise ValueError("duplicate value in 'axis'") + else: + axis = _validate_axis(axis, x.ndim) + # Nothing is masked: return x if m is nomask or not m.any(): return x._data |