summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMichael Seifert <michaelseifert04@yahoo.de>2016-09-05 16:53:59 +0200
committerMichael Seifert <michaelseifert04@yahoo.de>2016-09-05 16:53:59 +0200
commit4bcae47040aeef1127cc0d056f950f00d3d9e197 (patch)
tree48bdf1d021abe542db784b80fedea236fd96b5be /numpy
parentadc155e12648256eea754d1d53e8322e3ac19549 (diff)
downloadnumpy-4bcae47040aeef1127cc0d056f950f00d3d9e197.tar.gz
BUG: Fixes return for np.ma.count if keepdims is True and axis is None
The returned value is wrapped in an integer array of appropriate dimensions if keepdims is True and axis is None for np.ma.count. Also included: - Whitespace after "," (PEP8) - any instead of np.any when checking if any axis is out of bounds (performance)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py12
-rw-r--r--numpy/ma/tests/test_core.py1
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 1bf41b3d8..f83e2adcc 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -26,6 +26,11 @@ import sys
import warnings
from functools import reduce
+if sys.version_info[0] >= 3:
+ import builtins
+else:
+ import __builtin__ as builtins
+
import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
@@ -4356,13 +4361,15 @@ class MaskedArray(ndarray):
raise ValueError("'axis' entry is out of bounds")
return 1
elif axis is None:
+ if kwargs.get('keepdims', False):
+ return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size
axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(a if a >= 0 else self.ndim + a for a in axes)
if len(axes) != len(set(axes)):
raise ValueError("duplicate value in 'axis'")
- if np.any([a < 0 or a >= self.ndim for a in axes]):
+ if builtins.any(a < 0 or a >= self.ndim for a in axes):
raise ValueError("'axis' entry is out of bounds")
items = 1
for ax in axes:
@@ -4373,7 +4380,8 @@ class MaskedArray(ndarray):
for a in axes:
out_dims[a] = 1
else:
- out_dims = [d for n,d in enumerate(self.shape) if n not in axes]
+ out_dims = [d for n, d in enumerate(self.shape)
+ if n not in axes]
# make sure to return a 0-d array if axis is supplied
return np.full(out_dims, items, dtype=np.intp)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 7cac90628..338a6d0dc 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4364,6 +4364,7 @@ class TestOptionalArgs(TestCase):
assert_equal(count(a, axis=1), 3*ones((2,4)))
assert_equal(count(a, axis=(0,1)), 6*ones((4,)))
assert_equal(count(a, keepdims=True), 24*ones((1,1,1)))
+ assert_equal(np.ndim(count(a, keepdims=True)), 3)
assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4)))
assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4)))
assert_equal(count(a, axis=-2), 3*ones((2,4)))