summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py6
-rw-r--r--numpy/ma/extras.py16
2 files changed, 7 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 9386ff3be..6bc1bc623 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -43,6 +43,7 @@ from numpy.compat import (
)
from numpy import expand_dims as n_expand_dims
from numpy.core.multiarray import normalize_axis_index
+from numpy.core.numeric import _validate_axis
if sys.version_info[0] >= 3:
@@ -4369,10 +4370,7 @@ class MaskedArray(ndarray):
return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size
- axes = axis if isinstance(axis, tuple) else (axis,)
- axes = tuple(normalize_axis_index(a, self.ndim) for a in axes)
- if len(axes) != len(set(axes)):
- raise ValueError("duplicate value in 'axis'")
+ axes = _validate_axis(axis, self.ndim)
items = 1
for ax in axes:
items *= self.shape[ax]
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 697565251..8b37df902 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -37,6 +37,7 @@ import numpy as np
from numpy import ndarray, array as nxarray
import numpy.core.umath as umath
from numpy.core.multiarray import normalize_axis_index
+from numpy.core.numeric import _validate_axis
from numpy.lib.function_base import _ureduce
from numpy.lib.index_tricks import AxisConcatenator
@@ -816,18 +817,11 @@ def compress_nd(x, axis=None):
x = asarray(x)
m = getmask(x)
# Set axis to tuple of ints
- if isinstance(axis, int):
- axis = (axis,)
- elif axis is None:
+ if axis is None:
axis = tuple(range(x.ndim))
- elif not isinstance(axis, tuple):
- raise ValueError('Invalid type for axis argument')
- # Check axis input
- axis = [ax + x.ndim if ax < 0 else ax for ax in axis]
- if not all(0 <= ax < x.ndim for ax in axis):
- raise ValueError("'axis' entry is out of bounds")
- if len(axis) != len(set(axis)):
- raise ValueError("duplicate value in 'axis'")
+ else:
+ axis = _validate_axis(axis, x.ndim)
+
# Nothing is masked: return x
if m is nomask or not m.any():
return x._data