diff options
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r-- | numpy/ma/core.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 1bf41b3d8..f83e2adcc 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -26,6 +26,11 @@ import sys import warnings from functools import reduce +if sys.version_info[0] >= 3: + import builtins +else: + import __builtin__ as builtins + import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes @@ -4356,13 +4361,15 @@ class MaskedArray(ndarray): raise ValueError("'axis' entry is out of bounds") return 1 elif axis is None: + if kwargs.get('keepdims', False): + return np.array(self.size, dtype=np.intp, ndmin=self.ndim) return self.size axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(a if a >= 0 else self.ndim + a for a in axes) if len(axes) != len(set(axes)): raise ValueError("duplicate value in 'axis'") - if np.any([a < 0 or a >= self.ndim for a in axes]): + if builtins.any(a < 0 or a >= self.ndim for a in axes): raise ValueError("'axis' entry is out of bounds") items = 1 for ax in axes: @@ -4373,7 +4380,8 @@ class MaskedArray(ndarray): for a in axes: out_dims[a] = 1 else: - out_dims = [d for n,d in enumerate(self.shape) if n not in axes] + out_dims = [d for n, d in enumerate(self.shape) + if n not in axes] # make sure to return a 0-d array if axis is supplied return np.full(out_dims, items, dtype=np.intp) |