diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 18 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 15 |
2 files changed, 29 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index be07a12e3..de7485638 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3584,8 +3584,13 @@ class MaskedArray(ndarray): return masked omask = getattr(other, '_mask', nomask) if omask is nomask: - check = ndarray.__eq__(self.filled(0), other).view(type(self)) - check._mask = self._mask + check = ndarray.__eq__(self.filled(0), other) + try: + check = check.view(type(self)) + check._mask = self._mask + except AttributeError: + # Dang, we have a bool instead of an array: return the bool + return check else: odata = filled(other, 0) check = ndarray.__eq__(self.filled(0), odata).view(type(self)) @@ -3612,8 +3617,13 @@ class MaskedArray(ndarray): return masked omask = getattr(other, '_mask', nomask) if omask is nomask: - check = ndarray.__ne__(self.filled(0), other).view(type(self)) - check._mask = self._mask + check = ndarray.__ne__(self.filled(0), other) + try: + check = check.view(type(self)) + check._mask = self._mask + except AttributeError: + # In case check is a boolean (or a numpy.bool) + return check else: odata = filled(other, 0) check = ndarray.__ne__(self.filled(0), odata).view(type(self)) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index f95d621db..908d7adc6 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1149,6 +1149,21 @@ class TestMaskedArrayArithmetic(TestCase): assert_equal(test.mask, [False, False]) + def test_eq_w_None(self): + a = array([1, 2], mask=False) + assert_equal(a == None, False) + assert_equal(a != None, True) + a = masked + assert_equal(a == None, masked) + + def test_eq_w_scalar(self): + a = array(1) + assert_equal(a == 1, True) + assert_equal(a == 0, False) + assert_equal(a != 1, False) + assert_equal(a != 0, True) + + def test_numpyarithmetics(self): "Check that the mask is not back-propagated when using numpy functions" a = masked_array([-1, 0, 1, 2, 3], mask=[0, 0, 0, 0, 1]) |