diff options
author | Janani <jenny.stone125@gmail.com> | 2014-01-26 16:18:34 +0530 |
---|---|---|
committer | Janani <jenny.stone125@gmail.com> | 2014-01-27 18:10:08 +0530 |
commit | 3d8da08211406bafb6cf2b95794a3911447b22ba (patch) | |
tree | a0131fd1af11ac30cfdd058f2bcbc0140aaaca0e | |
parent | 2fc0c65c4f465762f5f479d173205c9158529a83 (diff) | |
download | numpy-3d8da08211406bafb6cf2b95794a3911447b22ba.tar.gz |
BUG: Removed the inconsistencies of the function ma.count
The inconsistency of ma.count when appled on masked_array with mask as
'nomask' has been removed. the return type of the function has also been
standardized according to the docs.Corresponding changes in testing to
check the functioning of ma.count
Closes gh-3368 and gh-4228
-rw-r--r-- | numpy/ma/core.py | 9 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 4 |
2 files changed, 6 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 8dc2ca86e..42787e3c7 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3976,21 +3976,16 @@ class MaskedArray(ndarray): """ m = self._mask s = self.shape - ls = len(s) if m is nomask: - if ls == 0: - return 1 - if ls == 1: - return s[0] if axis is None: return self.size else: n = s[axis] t = list(s) del t[axis] - return np.ones(t) * n + return np.full(t, n, dtype=np.intp) n1 = np.size(m, axis) - n2 = m.astype(int).sum(axis) + n2 = np.sum(m, axis=axis, dtype=np.intp) if axis is None: return (n1 - n2) else: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 764915236..8d8e1c947 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -796,6 +796,7 @@ class TestMaskedArrayArithmetic(TestCase): def test_count_func(self): # Tests count ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0]) + ott1= array([0., 1., 2., 3.]) if sys.version_info[0] >= 3: self.assertTrue(isinstance(count(ott), np.integer)) else: @@ -812,6 +813,9 @@ class TestMaskedArrayArithmetic(TestCase): assert_equal(3, count(ott)) assert_(getmask(count(ott, 0)) is nomask) assert_equal([1, 2], count(ott, 0)) + assert_equal(type(count(ott, 0)), type(count(ott1, 0))) + assert_equal(count(ott, 0).dtype, count(ott1, 0).dtype) + assert_raises(IndexError, ott1.count, 1) def test_minmax_func(self): # Tests minimum and maximum. |