diff options
Diffstat (limited to 'numpy/ma')
-rw-r--r-- | numpy/ma/core.py | 9 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 4 |
2 files changed, 6 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 8dc2ca86e..42787e3c7 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3976,21 +3976,16 @@ class MaskedArray(ndarray): """ m = self._mask s = self.shape - ls = len(s) if m is nomask: - if ls == 0: - return 1 - if ls == 1: - return s[0] if axis is None: return self.size else: n = s[axis] t = list(s) del t[axis] - return np.ones(t) * n + return np.full(t, n, dtype=np.intp) n1 = np.size(m, axis) - n2 = m.astype(int).sum(axis) + n2 = np.sum(m, axis=axis, dtype=np.intp) if axis is None: return (n1 - n2) else: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 764915236..8d8e1c947 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -796,6 +796,7 @@ class TestMaskedArrayArithmetic(TestCase): def test_count_func(self): # Tests count ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0]) + ott1= array([0., 1., 2., 3.]) if sys.version_info[0] >= 3: self.assertTrue(isinstance(count(ott), np.integer)) else: @@ -812,6 +813,9 @@ class TestMaskedArrayArithmetic(TestCase): assert_equal(3, count(ott)) assert_(getmask(count(ott, 0)) is nomask) assert_equal([1, 2], count(ott, 0)) + assert_equal(type(count(ott, 0)), type(count(ott1, 0))) + assert_equal(count(ott, 0).dtype, count(ott1, 0).dtype) + assert_raises(IndexError, ott1.count, 1) def test_minmax_func(self): # Tests minimum and maximum. |