diff options
-rw-r--r-- | numpy/ma/core.py | 12 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 1 |
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 1bf41b3d8..f83e2adcc 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -26,6 +26,11 @@ import sys import warnings from functools import reduce +if sys.version_info[0] >= 3: + import builtins +else: + import __builtin__ as builtins + import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes @@ -4356,13 +4361,15 @@ class MaskedArray(ndarray): raise ValueError("'axis' entry is out of bounds") return 1 elif axis is None: + if kwargs.get('keepdims', False): + return np.array(self.size, dtype=np.intp, ndmin=self.ndim) return self.size axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(a if a >= 0 else self.ndim + a for a in axes) if len(axes) != len(set(axes)): raise ValueError("duplicate value in 'axis'") - if np.any([a < 0 or a >= self.ndim for a in axes]): + if builtins.any(a < 0 or a >= self.ndim for a in axes): raise ValueError("'axis' entry is out of bounds") items = 1 for ax in axes: @@ -4373,7 +4380,8 @@ class MaskedArray(ndarray): for a in axes: out_dims[a] = 1 else: - out_dims = [d for n,d in enumerate(self.shape) if n not in axes] + out_dims = [d for n, d in enumerate(self.shape) + if n not in axes] # make sure to return a 0-d array if axis is supplied return np.full(out_dims, items, dtype=np.intp) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 7cac90628..338a6d0dc 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4364,6 +4364,7 @@ class TestOptionalArgs(TestCase): assert_equal(count(a, axis=1), 3*ones((2,4))) assert_equal(count(a, axis=(0,1)), 6*ones((4,))) assert_equal(count(a, keepdims=True), 24*ones((1,1,1))) + assert_equal(np.ndim(count(a, keepdims=True)), 3) assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4))) assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4))) assert_equal(count(a, axis=-2), 3*ones((2,4))) |