summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/tests/test_utils.py36
-rw-r--r--numpy/testing/utils.py21
2 files changed, 40 insertions, 17 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 4ca6c6354..a05fc3bdb 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -299,9 +299,24 @@ class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
a = np.array([[1., 2.], [3., 4.]])
b = np.ma.masked_array([[1., 2.], [0., 4.]],
[[False, False], [True, False]])
- assert_array_almost_equal(a, b)
- assert_array_almost_equal(b, a)
- assert_array_almost_equal(b, b)
+ self._assert_func(a, b)
+ self._assert_func(b, a)
+ self._assert_func(b, b)
+
+ def test_subclass_that_cannot_be_bool(self):
+ # While we cannot guarantee testing functions will always work for
+ # subclasses, the tests should ideally rely only on subclasses having
+ # comparison operators, not on them being able to store booleans
+ # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
+ class MyArray(np.ndarray):
+ def __lt__(self, other):
+ return super(MyArray, self).__lt__(other).view(np.ndarray)
+
+ def all(self, *args, **kwargs):
+ raise NotImplementedError
+
+ a = np.array([1., 2.]).view(MyArray)
+ self._assert_func(a, a)
class TestAlmostEqual(_GenericTest, unittest.TestCase):
@@ -387,6 +402,21 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
# remove anything that's not the array string
self.assertEqual(str(e).split('%)\n ')[1], b)
+ def test_subclass_that_cannot_be_bool(self):
+ # While we cannot guarantee testing functions will always work for
+ # subclasses, the tests should ideally rely only on subclasses having
+ # comparison operators, not on them being able to store booleans
+ # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
+ class MyArray(np.ndarray):
+ def __lt__(self, other):
+ return super(MyArray, self).__lt__(other).view(np.ndarray)
+
+ def all(self, *args, **kwargs):
+ raise NotImplementedError
+
+ a = np.array([1., 2.]).view(MyArray)
+ self._assert_func(a, a)
+
class TestApproxEqual(unittest.TestCase):
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 7858eefac..a44a51c81 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -669,8 +669,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, isinf, any, all, inf, zeros_like
- from numpy.core.numerictypes import bool_
+ from numpy.core import array, isnan, isinf, any, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -726,14 +725,13 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
if isnumber(x) and isnumber(y):
- x_id, y_id = zeros_like(x, dtype=bool_), zeros_like(y, dtype=bool_)
if equal_nan:
x_isnan, y_isnan = isnan(x), isnan(y)
# Validate that NaNs are in the same place
if any(x_isnan) or any(y_isnan):
chk_same_position(x_isnan, y_isnan, hasval='nan')
- x_id |= x_isnan
- y_id |= y_isnan
+ x = x[~x_isnan]
+ y = y[~y_isnan]
if equal_inf:
x_isinf, y_isinf = isinf(x), isinf(y)
@@ -742,19 +740,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# Check +inf and -inf separately, since they are different
chk_same_position(x == +inf, y == +inf, hasval='+inf')
chk_same_position(x == -inf, y == -inf, hasval='-inf')
- x_id |= x_isinf
- y_id |= y_isinf
+ x = x[~x_isinf]
+ y = y[~y_isinf]
# Only do the comparison if actual values are left
- if all(x_id):
+ if x.size == 0:
return
- if any(x_id):
- val = safe_comparison(x[~x_id], y[~y_id])
- else:
- val = safe_comparison(x, y)
- else:
- val = safe_comparison(x, y)
+ val = safe_comparison(x, y)
if isinstance(val, bool):
cond = val