diff options
author | Michael Seifert <michaelseifert04@yahoo.de> | 2016-09-05 16:53:59 +0200 |
---|---|---|
committer | Michael Seifert <michaelseifert04@yahoo.de> | 2016-09-05 16:53:59 +0200 |
commit | 4bcae47040aeef1127cc0d056f950f00d3d9e197 (patch) | |
tree | 48bdf1d021abe542db784b80fedea236fd96b5be /numpy | |
parent | adc155e12648256eea754d1d53e8322e3ac19549 (diff) | |
download | numpy-4bcae47040aeef1127cc0d056f950f00d3d9e197.tar.gz |
BUG: Fixes return for np.ma.count if keepdims is True and axis is None
The returned value is wrapped in an integer array of appropriate dimensions
if keepdims is True and axis is None for np.ma.count.
Also included:
- Whitespace after "," (PEP8)
- any instead of np.any when checking if any axis is out of bounds (performance)
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 12 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 1 |
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 1bf41b3d8..f83e2adcc 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -26,6 +26,11 @@ import sys import warnings from functools import reduce +if sys.version_info[0] >= 3: + import builtins +else: + import __builtin__ as builtins + import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes @@ -4356,13 +4361,15 @@ class MaskedArray(ndarray): raise ValueError("'axis' entry is out of bounds") return 1 elif axis is None: + if kwargs.get('keepdims', False): + return np.array(self.size, dtype=np.intp, ndmin=self.ndim) return self.size axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(a if a >= 0 else self.ndim + a for a in axes) if len(axes) != len(set(axes)): raise ValueError("duplicate value in 'axis'") - if np.any([a < 0 or a >= self.ndim for a in axes]): + if builtins.any(a < 0 or a >= self.ndim for a in axes): raise ValueError("'axis' entry is out of bounds") items = 1 for ax in axes: @@ -4373,7 +4380,8 @@ class MaskedArray(ndarray): for a in axes: out_dims[a] = 1 else: - out_dims = [d for n,d in enumerate(self.shape) if n not in axes] + out_dims = [d for n, d in enumerate(self.shape) + if n not in axes] # make sure to return a 0-d array if axis is supplied return np.full(out_dims, items, dtype=np.intp) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 7cac90628..338a6d0dc 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4364,6 +4364,7 @@ class TestOptionalArgs(TestCase): assert_equal(count(a, axis=1), 3*ones((2,4))) assert_equal(count(a, axis=(0,1)), 6*ones((4,))) assert_equal(count(a, keepdims=True), 24*ones((1,1,1))) + assert_equal(np.ndim(count(a, keepdims=True)), 3) assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4))) assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4))) assert_equal(count(a, axis=-2), 3*ones((2,4))) |