diff options
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r-- | numpy/ma/core.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index d8fd4f389..93eb74be3 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -4102,6 +4102,10 @@ class MaskedArray(ndarray): odata = getdata(other) if mask.dtype.names is not None: + # only == and != are reasonably defined for structured dtypes, + # so give up early for all other comparisons: + if compare not in (operator.eq, operator.ne): + return NotImplemented # For possibly masked structured arrays we need to be careful, # since the standard structured array comparison will use all # fields, masked or not. To avoid masked fields influencing the @@ -4124,10 +4128,11 @@ class MaskedArray(ndarray): if isinstance(check, (np.bool_, bool)): return masked if mask else check - if mask is not nomask: + if mask is not nomask and compare in (operator.eq, operator.ne): # Adjust elements that were masked, which should be treated # as equal if masked in both, unequal if masked in one. # Note that this works automatically for structured arrays too. + # Ignore this for operations other than `==` and `!=` check = np.where(mask, compare(smask, omask), check) if mask.shape != check.shape: # Guarantee consistency of the shape, making a copy since the @@ -4175,6 +4180,19 @@ class MaskedArray(ndarray): """ return self._comparison(other, operator.ne) + # All other comparisons: + def __le__(self, other): + return self._comparison(other, operator.le) + + def __lt__(self, other): + return self._comparison(other, operator.lt) + + def __ge__(self, other): + return self._comparison(other, operator.ge) + + def __gt__(self, other): + return self._comparison(other, operator.gt) + def __add__(self, other): """ Add self to other, and return a new masked array. |