summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJanani <jenny.stone125@gmail.com>2014-01-26 16:18:34 +0530
committerJanani <jenny.stone125@gmail.com>2014-01-27 18:10:08 +0530
commit3d8da08211406bafb6cf2b95794a3911447b22ba (patch)
treea0131fd1af11ac30cfdd058f2bcbc0140aaaca0e
parent2fc0c65c4f465762f5f479d173205c9158529a83 (diff)
downloadnumpy-3d8da08211406bafb6cf2b95794a3911447b22ba.tar.gz
BUG: Removed the inconsistencies of the function ma.count
The inconsistency of ma.count when appled on masked_array with mask as 'nomask' has been removed. the return type of the function has also been standardized according to the docs.Corresponding changes in testing to check the functioning of ma.count Closes gh-3368 and gh-4228
-rw-r--r--numpy/ma/core.py9
-rw-r--r--numpy/ma/tests/test_core.py4
2 files changed, 6 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 8dc2ca86e..42787e3c7 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3976,21 +3976,16 @@ class MaskedArray(ndarray):
"""
m = self._mask
s = self.shape
- ls = len(s)
if m is nomask:
- if ls == 0:
- return 1
- if ls == 1:
- return s[0]
if axis is None:
return self.size
else:
n = s[axis]
t = list(s)
del t[axis]
- return np.ones(t) * n
+ return np.full(t, n, dtype=np.intp)
n1 = np.size(m, axis)
- n2 = m.astype(int).sum(axis)
+ n2 = np.sum(m, axis=axis, dtype=np.intp)
if axis is None:
return (n1 - n2)
else:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 764915236..8d8e1c947 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -796,6 +796,7 @@ class TestMaskedArrayArithmetic(TestCase):
def test_count_func(self):
# Tests count
ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0])
+ ott1= array([0., 1., 2., 3.])
if sys.version_info[0] >= 3:
self.assertTrue(isinstance(count(ott), np.integer))
else:
@@ -812,6 +813,9 @@ class TestMaskedArrayArithmetic(TestCase):
assert_equal(3, count(ott))
assert_(getmask(count(ott, 0)) is nomask)
assert_equal([1, 2], count(ott, 0))
+ assert_equal(type(count(ott, 0)), type(count(ott1, 0)))
+ assert_equal(count(ott, 0).dtype, count(ott1, 0).dtype)
+ assert_raises(IndexError, ott1.count, 1)
def test_minmax_func(self):
# Tests minimum and maximum.