summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py20
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index d8fd4f389..93eb74be3 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -4102,6 +4102,10 @@ class MaskedArray(ndarray):
odata = getdata(other)
if mask.dtype.names is not None:
+ # only == and != are reasonably defined for structured dtypes,
+ # so give up early for all other comparisons:
+ if compare not in (operator.eq, operator.ne):
+ return NotImplemented
# For possibly masked structured arrays we need to be careful,
# since the standard structured array comparison will use all
# fields, masked or not. To avoid masked fields influencing the
@@ -4124,10 +4128,11 @@ class MaskedArray(ndarray):
if isinstance(check, (np.bool_, bool)):
return masked if mask else check
- if mask is not nomask:
+ if mask is not nomask and compare in (operator.eq, operator.ne):
# Adjust elements that were masked, which should be treated
# as equal if masked in both, unequal if masked in one.
# Note that this works automatically for structured arrays too.
+ # Ignore this for operations other than `==` and `!=`
check = np.where(mask, compare(smask, omask), check)
if mask.shape != check.shape:
# Guarantee consistency of the shape, making a copy since the
@@ -4175,6 +4180,19 @@ class MaskedArray(ndarray):
"""
return self._comparison(other, operator.ne)
+ # All other comparisons:
+ def __le__(self, other):
+ return self._comparison(other, operator.le)
+
+ def __lt__(self, other):
+ return self._comparison(other, operator.lt)
+
+ def __ge__(self, other):
+ return self._comparison(other, operator.ge)
+
+ def __gt__(self, other):
+ return self._comparison(other, operator.gt)
+
def __add__(self, other):
"""
Add self to other, and return a new masked array.