summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorseberg <sebastian@sipsolutions.net>2014-01-27 09:34:16 -0800
committerseberg <sebastian@sipsolutions.net>2014-01-27 09:34:16 -0800
commit05b236430617dfc6003f60b513c072870d577b2a (patch)
tree0087dbf1c4418880b4a802033d38b252094f3578 /numpy
parent0f7dffa43a06e936cd13910128f9aa5da9c7a105 (diff)
parent3d8da08211406bafb6cf2b95794a3911447b22ba (diff)
downloadnumpy-05b236430617dfc6003f60b513c072870d577b2a.tar.gz
Merge pull request #4234 from jennystone/Branch1
BUG: removed inconsistencies of count (Issues #3368 and #4228)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py9
-rw-r--r--numpy/ma/tests/test_core.py4
2 files changed, 6 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 8dc2ca86e..42787e3c7 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3976,21 +3976,16 @@ class MaskedArray(ndarray):
"""
m = self._mask
s = self.shape
- ls = len(s)
if m is nomask:
- if ls == 0:
- return 1
- if ls == 1:
- return s[0]
if axis is None:
return self.size
else:
n = s[axis]
t = list(s)
del t[axis]
- return np.ones(t) * n
+ return np.full(t, n, dtype=np.intp)
n1 = np.size(m, axis)
- n2 = m.astype(int).sum(axis)
+ n2 = np.sum(m, axis=axis, dtype=np.intp)
if axis is None:
return (n1 - n2)
else:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 764915236..8d8e1c947 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -796,6 +796,7 @@ class TestMaskedArrayArithmetic(TestCase):
def test_count_func(self):
# Tests count
ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0])
+ ott1= array([0., 1., 2., 3.])
if sys.version_info[0] >= 3:
self.assertTrue(isinstance(count(ott), np.integer))
else:
@@ -812,6 +813,9 @@ class TestMaskedArrayArithmetic(TestCase):
assert_equal(3, count(ott))
assert_(getmask(count(ott, 0)) is nomask)
assert_equal([1, 2], count(ott, 0))
+ assert_equal(type(count(ott, 0)), type(count(ott1, 0)))
+ assert_equal(count(ott, 0).dtype, count(ott1, 0).dtype)
+ assert_raises(IndexError, ott1.count, 1)
def test_minmax_func(self):
# Tests minimum and maximum.