summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-08 22:17:49 -0500
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-27 11:24:17 -0500
commitf3f91be7920f52cd6abace64d7e14b876e5f054f (patch)
treee562c7ef7c2884cf059d2ddd8c618fd8c00d0464 /numpy
parentad8afe82e7b7643607a348c0e02b45c9131c6a06 (diff)
downloadnumpy-f3f91be7920f52cd6abace64d7e14b876e5f054f.tar.gz
BUG: MaskedArray __eq__ wrong for masked scalar, multi-d recarray
In the process of trying to fix the "questionable behaviour in `MaskedArray.__eq__`" (gh-8589), it became clear that the code was buggy. E.g., `ma == ma[0]` failed if `ma` held a structured dtype; multi-d structured dtypes failed generally; and, more worryingly, a masked scalar comparison could be wrong: `np.ma.MaskedArray(1, mask=True) == 0` yields True. This commit solves these problems, adding tests to prevent regression. In the process, it also ensures that the results for structured arrays always equals what one would get by logically combining the results over individual parts of the structure.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py139
-rw-r--r--numpy/ma/tests/test_core.py89
2 files changed, 155 insertions, 73 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 3b2b39b18..35d4b72bc 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -23,6 +23,7 @@ Released for unlimited redistribution.
from __future__ import division, absolute_import, print_function
import sys
+import operator
import warnings
from functools import reduce
@@ -1733,7 +1734,8 @@ def mask_or(m1, m2, copy=False, shrink=True):
if (dtype1 != dtype2):
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
if dtype1.names:
- newmask = np.empty_like(m1)
+ # Allocate an output mask array with the properly broadcast shape.
+ newmask = np.empty(np.broadcast(m1, m2).shape, dtype1)
_recursive_mask_or(m1, m2, newmask)
return newmask
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -3873,81 +3875,84 @@ class MaskedArray(ndarray):
return True
return False
- def __eq__(self, other):
- """
- Check whether other equals self elementwise.
+ def _comparison(self, other, compare):
+ """Compare self with other using operator.eq or operator.ne.
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- if self is masked:
- return masked
omask = getmask(other)
- if omask is nomask:
- check = self.filled(0).__eq__(other)
- try:
- check = check.view(type(self))
- check._mask = self._mask
- except AttributeError:
- # Dang, we have a bool instead of an array: return the bool
- return check
+ smask = self.mask
+ mask = mask_or(smask, omask, copy=True)
+
+ odata = getdata(other)
+ if mask.dtype.names:
+ # For possibly masked structured arrays we need to be careful,
+ # since the standard structured array comparison will use all
+ # fields, masked or not. To avoid masked fields influencing the
+ # outcome, we set all masked fields in self to other, so they'll
+ # count as equal. To prepare, we ensure we have the right shape.
+ broadcast_shape = np.broadcast(self, odata).shape
+ sbroadcast = np.broadcast_to(self, broadcast_shape, subok=True)
+ sbroadcast._mask = mask
+ sdata = sbroadcast.filled(odata)
+ # Now take care of the mask; the merged mask should have an item
+ # masked if all fields were masked (in one and/or other).
+ mask = (mask == np.ones((), mask.dtype))
+
else:
- odata = filled(other, 0)
- check = self.filled(0).__eq__(odata).view(type(self))
- if self._mask is nomask:
- check._mask = omask
- else:
- mask = mask_or(self._mask, omask)
- if mask.dtype.names:
- if mask.size > 1:
- axis = 1
- else:
- axis = None
- try:
- mask = mask.view((bool_, len(self.dtype))).all(axis)
- except (ValueError, np.AxisError):
- # TODO: what error are we trying to catch here?
- # invalid axis, or invalid view?
- mask = np.all([[f[n].all() for n in mask.dtype.names]
- for f in mask], axis=axis)
- check._mask = mask
+ # For regular arrays, just use the data as they come.
+ sdata = self.data
+
+ check = compare(sdata, odata)
+
+ if isinstance(check, (np.bool_, bool)):
+ return masked if mask else check
+
+ if mask is not nomask:
+ # Adjust elements that were masked, which should be treated
+ # as equal if masked in both, unequal if masked in one.
+ # Note that this works automatically for structured arrays too.
+ check = np.where(mask, compare(smask, omask), check)
+ if mask.shape != check.shape:
+ # Guarantee consistency of the shape, making a copy since the
+ # the mask may need to get written to later.
+ mask = np.broadcast_to(mask, check.shape).copy()
+
+ check = check.view(type(self))
+ check._mask = mask
return check
- def __ne__(self, other):
+ def __eq__(self, other):
+ """Check whether other equals self elementwise.
+
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- Check whether other doesn't equal self elementwise
+ return self._comparison(other, operator.eq)
+
+ def __ne__(self, other):
+ """Check whether other does not equal self elementwise.
+
+ When either of the elements is masked, the result is masked as well,
+ but the underlying boolean data are still set, with self and other
+ considered equal if both are masked, and unequal otherwise.
+ For structured arrays, all fields are combined, with masked values
+ ignored. The result is masked if all fields were masked, with self
+ and other considered equal only if both were fully masked.
"""
- if self is masked:
- return masked
- omask = getmask(other)
- if omask is nomask:
- check = self.filled(0).__ne__(other)
- try:
- check = check.view(type(self))
- check._mask = self._mask
- except AttributeError:
- # In case check is a boolean (or a numpy.bool)
- return check
- else:
- odata = filled(other, 0)
- check = self.filled(0).__ne__(odata).view(type(self))
- if self._mask is nomask:
- check._mask = omask
- else:
- mask = mask_or(self._mask, omask)
- if mask.dtype.names:
- if mask.size > 1:
- axis = 1
- else:
- axis = None
- try:
- mask = mask.view((bool_, len(self.dtype))).all(axis)
- except (ValueError, np.AxisError):
- # TODO: what error are we trying to catch here?
- # invalid axis, or invalid view?
- mask = np.all([[f[n].all() for n in mask.dtype.names]
- for f in mask], axis=axis)
- check._mask = mask
- return check
+ return self._comparison(other, operator.ne)
def __add__(self, other):
"""
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index f9d032f09..d64f1acdc 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1335,32 +1335,95 @@ class TestMaskedArrayArithmetic(TestCase):
ndtype = [('A', int), ('B', int)]
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
test = (a == a)
- assert_equal(test, [True, True])
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [False, False])
+ test = (a == a[0])
+ assert_equal(test.data, [True, False])
assert_equal(test.mask, [False, False])
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
test = (a == b)
- assert_equal(test, [False, True])
+ assert_equal(test.data, [False, True])
+ assert_equal(test.mask, [True, False])
+ test = (a[0] == b)
+ assert_equal(test.data, [False, False])
assert_equal(test.mask, [True, False])
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
test = (a == b)
- assert_equal(test, [True, False])
+ assert_equal(test.data, [True, True])
assert_equal(test.mask, [False, False])
+ # complicated dtype, 2-dimensional array.
+ ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
+ a = array([[(1, (1, 1)), (2, (2, 2))],
+ [(3, (3, 3)), (4, (4, 4))]],
+ mask=[[(0, (1, 0)), (0, (0, 1))],
+ [(1, (0, 0)), (1, (1, 1))]], dtype=ndtype)
+ test = (a[0, 0] == a)
+ assert_equal(test.data, [[True, False], [False, False]])
+ assert_equal(test.mask, [[False, False], [False, True]])
def test_ne_on_structured(self):
# Test the equality of structured arrays
ndtype = [('A', int), ('B', int)]
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
test = (a != a)
- assert_equal(test, [False, False])
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [False, False])
+ test = (a != a[0])
+ assert_equal(test.data, [False, True])
assert_equal(test.mask, [False, False])
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
test = (a != b)
- assert_equal(test, [True, False])
+ assert_equal(test.data, [True, False])
+ assert_equal(test.mask, [True, False])
+ test = (a[0] != b)
+ assert_equal(test.data, [True, True])
assert_equal(test.mask, [True, False])
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
test = (a != b)
- assert_equal(test, [False, True])
+ assert_equal(test.data, [False, False])
assert_equal(test.mask, [False, False])
+ # complicated dtype, 2-dimensional array.
+ ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
+ a = array([[(1, (1, 1)), (2, (2, 2))],
+ [(3, (3, 3)), (4, (4, 4))]],
+ mask=[[(0, (1, 0)), (0, (0, 1))],
+ [(1, (0, 0)), (1, (1, 1))]], dtype=ndtype)
+ test = (a[0, 0] != a)
+ assert_equal(test.data, [[False, True], [True, True]])
+ assert_equal(test.mask, [[False, False], [False, True]])
+
+ def test_eq_ne_structured_extra(self):
+ # ensure simple examples are symmetric and make sense.
+ # from https://github.com/numpy/numpy/pull/8590#discussion_r101126465
+ dt = np.dtype('i4,i4')
+ for m1 in (mvoid((1, 2), mask=(0, 0), dtype=dt),
+ mvoid((1, 2), mask=(0, 1), dtype=dt),
+ mvoid((1, 2), mask=(1, 0), dtype=dt),
+ mvoid((1, 2), mask=(1, 1), dtype=dt)):
+ ma1 = m1.view(MaskedArray)
+ r1 = ma1.view('2i4')
+ for m2 in (mvoid((1, 1), dtype=dt),
+ mvoid((1, 0), mask=(0, 1), dtype=dt),
+ mvoid((3, 2), mask=(0, 1), dtype=dt)):
+ ma2 = m2.view(MaskedArray)
+ r2 = ma2.view('2i4')
+ eq_expected = (r1 == r2).all()
+ assert_equal(m1 == m2, eq_expected)
+ assert_equal(m2 == m1, eq_expected)
+ assert_equal(ma1 == m2, eq_expected)
+ assert_equal(m1 == ma2, eq_expected)
+ assert_equal(ma1 == ma2, eq_expected)
+ # Also check it is the same if we do it element by element.
+ el_by_el = [m1[name] == m2[name] for name in dt.names]
+ assert_equal(array(el_by_el, dtype=bool).all(), eq_expected)
+ ne_expected = (r1 != r2).any()
+ assert_equal(m1 != m2, ne_expected)
+ assert_equal(m2 != m1, ne_expected)
+ assert_equal(ma1 != m2, ne_expected)
+ assert_equal(m1 != ma2, ne_expected)
+ assert_equal(ma1 != ma2, ne_expected)
+ el_by_el = [m1[name] != m2[name] for name in dt.names]
+ assert_equal(array(el_by_el, dtype=bool).any(), ne_expected)
def test_eq_with_None(self):
# Really, comparisons with None should not be done, but check them
@@ -1393,6 +1456,20 @@ class TestMaskedArrayArithmetic(TestCase):
assert_equal(a == 0, False)
assert_equal(a != 1, False)
assert_equal(a != 0, True)
+ b = array(1, mask=True)
+ assert_equal(b == 0, masked)
+ assert_equal(b == 1, masked)
+ assert_equal(b != 0, masked)
+ assert_equal(b != 1, masked)
+
+ def test_eq_different_dimensions(self):
+ m1 = array([[0, 1], [1, 2]])
+ m2 = array([1, 1], mask=[0, 1])
+ test = (m1 == m2)
+ assert_equal(test, [[False, False],
+ [True, False]])
+ assert_equal(test.mask, [[False, True],
+ [False, True]])
def test_numpyarithmetics(self):
# Check that the mask is not back-propagated when using numpy functions