diff options
author | seberg <sebastian@sipsolutions.net> | 2014-01-27 09:34:16 -0800 |
---|---|---|
committer | seberg <sebastian@sipsolutions.net> | 2014-01-27 09:34:16 -0800 |
commit | 05b236430617dfc6003f60b513c072870d577b2a (patch) | |
tree | 0087dbf1c4418880b4a802033d38b252094f3578 /numpy | |
parent | 0f7dffa43a06e936cd13910128f9aa5da9c7a105 (diff) | |
parent | 3d8da08211406bafb6cf2b95794a3911447b22ba (diff) | |
download | numpy-05b236430617dfc6003f60b513c072870d577b2a.tar.gz |
Merge pull request #4234 from jennystone/Branch1
BUG: removed inconsistencies of count (Issues #3368 and #4228)
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 9 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 4 |
2 files changed, 6 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 8dc2ca86e..42787e3c7 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3976,21 +3976,16 @@ class MaskedArray(ndarray): """ m = self._mask s = self.shape - ls = len(s) if m is nomask: - if ls == 0: - return 1 - if ls == 1: - return s[0] if axis is None: return self.size else: n = s[axis] t = list(s) del t[axis] - return np.ones(t) * n + return np.full(t, n, dtype=np.intp) n1 = np.size(m, axis) - n2 = m.astype(int).sum(axis) + n2 = np.sum(m, axis=axis, dtype=np.intp) if axis is None: return (n1 - n2) else: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 764915236..8d8e1c947 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -796,6 +796,7 @@ class TestMaskedArrayArithmetic(TestCase): def test_count_func(self): # Tests count ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0]) + ott1= array([0., 1., 2., 3.]) if sys.version_info[0] >= 3: self.assertTrue(isinstance(count(ott), np.integer)) else: @@ -812,6 +813,9 @@ class TestMaskedArrayArithmetic(TestCase): assert_equal(3, count(ott)) assert_(getmask(count(ott, 0)) is nomask) assert_equal([1, 2], count(ott, 0)) + assert_equal(type(count(ott, 0)), type(count(ott1, 0))) + assert_equal(count(ott, 0).dtype, count(ott1, 0).dtype) + assert_raises(IndexError, ott1.count, 1) def test_minmax_func(self): # Tests minimum and maximum. |