summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-09-05 11:40:43 -0500
committerGitHub <noreply@github.com>2016-09-05 11:40:43 -0500
commit175cc57245a78e4037f80835e63dbbe077d43cae (patch)
tree48bdf1d021abe542db784b80fedea236fd96b5be
parentadc155e12648256eea754d1d53e8322e3ac19549 (diff)
parent4bcae47040aeef1127cc0d056f950f00d3d9e197 (diff)
downloadnumpy-175cc57245a78e4037f80835e63dbbe077d43cae.tar.gz
Merge pull request #8018 from MSeifert04/bugfix_ma_count_keepdims
BUG: Fixes return for np.ma.count if keepdims is True and axis is None
-rw-r--r--numpy/ma/core.py12
-rw-r--r--numpy/ma/tests/test_core.py1
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 1bf41b3d8..f83e2adcc 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -26,6 +26,11 @@ import sys
import warnings
from functools import reduce
+if sys.version_info[0] >= 3:
+ import builtins
+else:
+ import __builtin__ as builtins
+
import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
@@ -4356,13 +4361,15 @@ class MaskedArray(ndarray):
raise ValueError("'axis' entry is out of bounds")
return 1
elif axis is None:
+ if kwargs.get('keepdims', False):
+ return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size
axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(a if a >= 0 else self.ndim + a for a in axes)
if len(axes) != len(set(axes)):
raise ValueError("duplicate value in 'axis'")
- if np.any([a < 0 or a >= self.ndim for a in axes]):
+ if builtins.any(a < 0 or a >= self.ndim for a in axes):
raise ValueError("'axis' entry is out of bounds")
items = 1
for ax in axes:
@@ -4373,7 +4380,8 @@ class MaskedArray(ndarray):
for a in axes:
out_dims[a] = 1
else:
- out_dims = [d for n,d in enumerate(self.shape) if n not in axes]
+ out_dims = [d for n, d in enumerate(self.shape)
+ if n not in axes]
# make sure to return a 0-d array if axis is supplied
return np.full(out_dims, items, dtype=np.intp)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 7cac90628..338a6d0dc 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4364,6 +4364,7 @@ class TestOptionalArgs(TestCase):
assert_equal(count(a, axis=1), 3*ones((2,4)))
assert_equal(count(a, axis=(0,1)), 6*ones((4,)))
assert_equal(count(a, keepdims=True), 24*ones((1,1,1)))
+ assert_equal(np.ndim(count(a, keepdims=True)), 3)
assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4)))
assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4)))
assert_equal(count(a, axis=-2), 3*ones((2,4)))