diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 37 | ||||
-rw-r--r-- | numpy/testing/utils.py | 76 |
2 files changed, 75 insertions, 38 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 804f22b7f..e2c105245 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -3,6 +3,7 @@ from __future__ import division, absolute_import, print_function import warnings import sys import os +import itertools import numpy as np from numpy.testing import ( @@ -144,7 +145,10 @@ class TestArrayEqual(_GenericTest, unittest.TestCase): c['floupipi'] = a['floupi'].copy() c['floupa'] = a['floupa'].copy() - self._test_not_equal(c, b) + with suppress_warnings() as sup: + l = sup.record(FutureWarning, message="elementwise == ") + self._test_not_equal(c, b) + assert_(len(l) == 1) class TestBuildErrorMessage(unittest.TestCase): @@ -208,6 +212,37 @@ class TestEqual(TestArrayEqual): self._assert_func([np.inf], [np.inf]) self._test_not_equal(np.inf, [np.inf]) + def test_nat_items(self): + # not a datetime + nadt_no_unit = np.datetime64("NaT") + nadt_s = np.datetime64("NaT", "s") + nadt_d = np.datetime64("NaT", "ns") + # not a timedelta + natd_no_unit = np.timedelta64("NaT") + natd_s = np.timedelta64("NaT", "s") + natd_d = np.timedelta64("NaT", "ns") + + dts = [nadt_no_unit, nadt_s, nadt_d] + tds = [natd_no_unit, natd_s, natd_d] + for a, b in itertools.product(dts, dts): + self._assert_func(a, b) + self._assert_func([a], [b]) + self._test_not_equal([a], b) + + for a, b in itertools.product(tds, tds): + self._assert_func(a, b) + self._assert_func([a], [b]) + self._test_not_equal([a], b) + + for a, b in itertools.product(tds, dts): + self._test_not_equal(a, b) + self._test_not_equal(a, [b]) + self._test_not_equal([a], [b]) + self._test_not_equal([a], np.datetime64("2017-01-01", "s")) + self._test_not_equal([b], np.datetime64("2017-01-01", "s")) + self._test_not_equal([a], np.timedelta64(123, "s")) + self._test_not_equal([b], np.timedelta64(123, "s")) + def test_non_numeric(self): self._assert_func('ab', 'ab') self._test_not_equal('ab', 'abb') diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index b5a7e05c4..f54995870 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -15,7 +15,8 @@ import contextlib from tempfile import mkdtemp, mkstemp from unittest.case import SkipTest -from numpy.core import float32, empty, arange, array_repr, ndarray +from numpy.core import( + float32, empty, arange, array_repr, ndarray, isnat, array) from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: @@ -286,7 +287,7 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', return '\n'.join(msg) -def assert_equal(actual,desired,err_msg='',verbose=True): +def assert_equal(actual, desired, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal. @@ -369,12 +370,12 @@ def assert_equal(actual,desired,err_msg='',verbose=True): except AssertionError: raise AssertionError(msg) + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + # Inf/nan/negative zero handling try: - # isscalar test to check cases such as [np.nan] != np.nan - if isscalar(desired) != isscalar(actual): - raise AssertionError(msg) - # If one of desired/actual is not finite, handle it specially here: # check that both are nan if any is a nan, and test for equality # otherwise @@ -396,14 +397,24 @@ def assert_equal(actual,desired,err_msg='',verbose=True): except (TypeError, ValueError, NotImplementedError): pass - # Explicitly use __eq__ for comparison, ticket #2552 - with suppress_warnings() as sup: - # TODO: Better handling will to needed when change happens! - sup.filter(DeprecationWarning, ".*NAT ==") - sup.filter(FutureWarning, ".*NAT ==") - if not (desired == actual): + try: + # If both are NaT (and have the same dtype -- datetime or timedelta) + # they are considered equal. + if (isnat(desired) == isnat(actual) and + array(desired).dtype.type == array(actual).dtype.type): + return + else: raise AssertionError(msg) + # If TypeError or ValueError raised while using isnan and co, just handle + # as before + except (TypeError, ValueError, NotImplementedError): + pass + + # Explicitly use __eq__ for comparison, ticket #2552 + if not (desired == actual): + raise AssertionError(msg) + def print_assert_equal(test_string, actual, desired): """ @@ -674,33 +685,12 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) - def safe_comparison(*args, **kwargs): - # There are a number of cases where comparing two arrays hits special - # cases in array_richcompare, specifically around strings and void - # dtypes. Basically, we just can't do comparisons involving these - # types, unless both arrays have exactly the *same* type. So - # e.g. you can apply == to two string arrays, or two arrays with - # identical structured dtypes. But if you compare a non-string array - # to a string array, or two arrays with non-identical structured - # dtypes, or anything like that, then internally stuff blows up. - # Currently, when things blow up, we just return a scalar False or - # True. But we also emit a DeprecationWarning, b/c eventually we - # should raise an error here. (Ideally we might even make this work - # properly, but since that will require rewriting a bunch of how - # ufuncs work then we are not counting on that.) - # - # The point of this little function is to let the DeprecationWarning - # pass (or maybe eventually catch the errors and return False, I - # dunno, that's a little trickier and we can figure that out when the - # time comes). - with suppress_warnings() as sup: - sup.filter(DeprecationWarning, ".*==") - sup.filter(FutureWarning, ".*==") - return comparison(*args, **kwargs) - def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' + def istime(x): + return x.dtype.char in "Mm" + def chk_same_position(x_id, y_id, hasval='nan'): """Handling nan/inf: check that x and y have the nan/inf at the same locations.""" @@ -756,7 +746,19 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, if x.size == 0: return - val = safe_comparison(x, y) + elif istime(x) and istime(y): + # If one is datetime64 and the other timedelta64 there is no point + if equal_nan and x.dtype.type == y.dtype.type: + x_isnat, y_isnat = isnat(x), isnat(y) + + if any(x_isnat) or any(y_isnat): + chk_same_position(x_isnat, y_isnat, hasval="NaT") + + if any(x_isnat) or any(y_isnat): + x = x[~x_isnat] + y = y[~y_isnat] + + val = comparison(x, y) if isinstance(val, bool): cond = val |