diff options
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 4ca6c6354..a05fc3bdb 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -299,9 +299,24 @@ class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): a = np.array([[1., 2.], [3., 4.]]) b = np.ma.masked_array([[1., 2.], [0., 4.]], [[False, False], [True, False]]) - assert_array_almost_equal(a, b) - assert_array_almost_equal(b, a) - assert_array_almost_equal(b, b) + self._assert_func(a, b) + self._assert_func(b, a) + self._assert_func(b, b) + + def test_subclass_that_cannot_be_bool(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __lt__(self, other): + return super(MyArray, self).__lt__(other).view(np.ndarray) + + def all(self, *args, **kwargs): + raise NotImplementedError + + a = np.array([1., 2.]).view(MyArray) + self._assert_func(a, a) class TestAlmostEqual(_GenericTest, unittest.TestCase): @@ -387,6 +402,21 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase): # remove anything that's not the array string self.assertEqual(str(e).split('%)\n ')[1], b) + def test_subclass_that_cannot_be_bool(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __lt__(self, other): + return super(MyArray, self).__lt__(other).view(np.ndarray) + + def all(self, *args, **kwargs): + raise NotImplementedError + + a = np.array([1., 2.]).view(MyArray) + self._assert_func(a, a) + class TestApproxEqual(unittest.TestCase): |