summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/ma/core.py20
-rw-r--r--numpy/ma/tests/test_core.py46
2 files changed, 65 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index d8fd4f389..93eb74be3 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -4102,6 +4102,10 @@ class MaskedArray(ndarray):
odata = getdata(other)
if mask.dtype.names is not None:
+ # only == and != are reasonably defined for structured dtypes,
+ # so give up early for all other comparisons:
+ if compare not in (operator.eq, operator.ne):
+ return NotImplemented
# For possibly masked structured arrays we need to be careful,
# since the standard structured array comparison will use all
# fields, masked or not. To avoid masked fields influencing the
@@ -4124,10 +4128,11 @@ class MaskedArray(ndarray):
if isinstance(check, (np.bool_, bool)):
return masked if mask else check
- if mask is not nomask:
+ if mask is not nomask and compare in (operator.eq, operator.ne):
# Adjust elements that were masked, which should be treated
# as equal if masked in both, unequal if masked in one.
# Note that this works automatically for structured arrays too.
+ # Ignore this for operations other than `==` and `!=`
check = np.where(mask, compare(smask, omask), check)
if mask.shape != check.shape:
# Guarantee consistency of the shape, making a copy since the
@@ -4175,6 +4180,19 @@ class MaskedArray(ndarray):
"""
return self._comparison(other, operator.ne)
+ # All other comparisons:
+ def __le__(self, other):
+ return self._comparison(other, operator.le)
+
+ def __lt__(self, other):
+ return self._comparison(other, operator.lt)
+
+ def __ge__(self, other):
+ return self._comparison(other, operator.ge)
+
+ def __gt__(self, other):
+ return self._comparison(other, operator.gt)
+
def __add__(self, other):
"""
Add self to other, and return a new masked array.
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 4fac897de..b056d5169 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1756,6 +1756,52 @@ class TestMaskedArrayArithmetic:
assert_equal(test.mask, [True, False])
assert_(test.fill_value == True)
+ @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('fill', [None, 1])
+ @pytest.mark.parametrize('op',
+ [operator.le, operator.lt, operator.ge, operator.gt])
+ def test_comparisons_for_numeric(self, op, dt1, dt2, fill):
+ # Test the equality of structured arrays
+ a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+ test = op(a, a)
+ assert_equal(test.data, op(a._data, a._data))
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ test = op(a, a[0])
+ assert_equal(test.data, op(a._data, a._data[0]))
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill)
+ test = op(a, b)
+ assert_equal(test.data, op(a._data, b._data))
+ assert_equal(test.mask, [True, True])
+ assert_(test.fill_value == True)
+
+ test = op(a[0], b)
+ assert_equal(test.data, op(a._data[0], b._data))
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
+ test = op(b, a[0])
+ assert_equal(test.data, op(b._data, a._data[0]))
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
+ @pytest.mark.parametrize('op',
+ [operator.le, operator.lt, operator.ge, operator.gt])
+ @pytest.mark.parametrize('fill', [None, "N/A"])
+ def test_comparisons_strings(self, op, fill):
+ # See gh-21770, mask propagation is broken for strings (and some other
+ # cases) so we explicitly test strings here.
+ # In principle only == and != may need special handling...
+ ma1 = masked_array(["a", "b", "cde"], mask=[0, 1, 0], fill_value=fill)
+ ma2 = masked_array(["cde", "b", "a"], mask=[0, 1, 0], fill_value=fill)
+ assert_equal(op(ma1, ma2)._data, op(ma1._data, ma2._data))
+
def test_eq_with_None(self):
# Really, comparisons with None should not be done, but check them
# anyway. Note that pep8 will flag these tests.