summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py139
1 files changed, 72 insertions, 67 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 3b2b39b18..35d4b72bc 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -23,6 +23,7 @@ Released for unlimited redistribution.
from __future__ import division, absolute_import, print_function
import sys
+import operator
import warnings
from functools import reduce
@@ -1733,7 +1734,8 @@ def mask_or(m1, m2, copy=False, shrink=True):
if (dtype1 != dtype2):
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
if dtype1.names:
- newmask = np.empty_like(m1)
+ # Allocate an output mask array with the properly broadcast shape.
+ newmask = np.empty(np.broadcast(m1, m2).shape, dtype1)
_recursive_mask_or(m1, m2, newmask)
return newmask
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -3873,81 +3875,84 @@ class MaskedArray(ndarray):
return True
return False
- def __eq__(self, other):
- """
- Check whether other equals self elementwise.
+ def _comparison(self, other, compare):
+ """Compare self with other using operator.eq or operator.ne.
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- if self is masked:
- return masked
omask = getmask(other)
- if omask is nomask:
- check = self.filled(0).__eq__(other)
- try:
- check = check.view(type(self))
- check._mask = self._mask
- except AttributeError:
- # Dang, we have a bool instead of an array: return the bool
- return check
+ smask = self.mask
+ mask = mask_or(smask, omask, copy=True)
+
+ odata = getdata(other)
+ if mask.dtype.names:
+ # For possibly masked structured arrays we need to be careful,
+ # since the standard structured array comparison will use all
+ # fields, masked or not. To avoid masked fields influencing the
+ # outcome, we set all masked fields in self to other, so they'll
+ # count as equal. To prepare, we ensure we have the right shape.
+ broadcast_shape = np.broadcast(self, odata).shape
+ sbroadcast = np.broadcast_to(self, broadcast_shape, subok=True)
+ sbroadcast._mask = mask
+ sdata = sbroadcast.filled(odata)
+ # Now take care of the mask; the merged mask should have an item
+ # masked if all fields were masked (in one and/or other).
+ mask = (mask == np.ones((), mask.dtype))
+
else:
- odata = filled(other, 0)
- check = self.filled(0).__eq__(odata).view(type(self))
- if self._mask is nomask:
- check._mask = omask
- else:
- mask = mask_or(self._mask, omask)
- if mask.dtype.names:
- if mask.size > 1:
- axis = 1
- else:
- axis = None
- try:
- mask = mask.view((bool_, len(self.dtype))).all(axis)
- except (ValueError, np.AxisError):
- # TODO: what error are we trying to catch here?
- # invalid axis, or invalid view?
- mask = np.all([[f[n].all() for n in mask.dtype.names]
- for f in mask], axis=axis)
- check._mask = mask
+ # For regular arrays, just use the data as they come.
+ sdata = self.data
+
+ check = compare(sdata, odata)
+
+ if isinstance(check, (np.bool_, bool)):
+ return masked if mask else check
+
+ if mask is not nomask:
+ # Adjust elements that were masked, which should be treated
+ # as equal if masked in both, unequal if masked in one.
+ # Note that this works automatically for structured arrays too.
+ check = np.where(mask, compare(smask, omask), check)
+ if mask.shape != check.shape:
+ # Guarantee consistency of the shape, making a copy since the
+ # the mask may need to get written to later.
+ mask = np.broadcast_to(mask, check.shape).copy()
+
+ check = check.view(type(self))
+ check._mask = mask
return check
- def __ne__(self, other):
+ def __eq__(self, other):
+ """Check whether other equals self elementwise.
+
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- Check whether other doesn't equal self elementwise
+ return self._comparison(other, operator.eq)
+
+ def __ne__(self, other):
+ """Check whether other does not equal self elementwise.
+
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- if self is masked:
- return masked
- omask = getmask(other)
- if omask is nomask:
- check = self.filled(0).__ne__(other)
- try:
- check = check.view(type(self))
- check._mask = self._mask
- except AttributeError:
- # In case check is a boolean (or a numpy.bool)
- return check
- else:
- odata = filled(other, 0)
- check = self.filled(0).__ne__(odata).view(type(self))
- if self._mask is nomask:
- check._mask = omask
- else:
- mask = mask_or(self._mask, omask)
- if mask.dtype.names:
- if mask.size > 1:
- axis = 1
- else:
- axis = None
- try:
- mask = mask.view((bool_, len(self.dtype))).all(axis)
- except (ValueError, np.AxisError):
- # TODO: what error are we trying to catch here?
- # invalid axis, or invalid view?
- mask = np.all([[f[n].all() for n in mask.dtype.names]
- for f in mask], axis=axis)
- check._mask = mask
- return check
+ return self._comparison(other, operator.ne)
def __add__(self, other):
"""