diff options
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 804f22b7f..e2c105245 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -3,6 +3,7 @@ from __future__ import division, absolute_import, print_function import warnings import sys import os +import itertools import numpy as np from numpy.testing import ( @@ -144,7 +145,10 @@ class TestArrayEqual(_GenericTest, unittest.TestCase): c['floupipi'] = a['floupi'].copy() c['floupa'] = a['floupa'].copy() - self._test_not_equal(c, b) + with suppress_warnings() as sup: + l = sup.record(FutureWarning, message="elementwise == ") + self._test_not_equal(c, b) + assert_(len(l) == 1) class TestBuildErrorMessage(unittest.TestCase): @@ -208,6 +212,37 @@ class TestEqual(TestArrayEqual): self._assert_func([np.inf], [np.inf]) self._test_not_equal(np.inf, [np.inf]) + def test_nat_items(self): + # not a datetime + nadt_no_unit = np.datetime64("NaT") + nadt_s = np.datetime64("NaT", "s") + nadt_d = np.datetime64("NaT", "ns") + # not a timedelta + natd_no_unit = np.timedelta64("NaT") + natd_s = np.timedelta64("NaT", "s") + natd_d = np.timedelta64("NaT", "ns") + + dts = [nadt_no_unit, nadt_s, nadt_d] + tds = [natd_no_unit, natd_s, natd_d] + for a, b in itertools.product(dts, dts): + self._assert_func(a, b) + self._assert_func([a], [b]) + self._test_not_equal([a], b) + + for a, b in itertools.product(tds, tds): + self._assert_func(a, b) + self._assert_func([a], [b]) + self._test_not_equal([a], b) + + for a, b in itertools.product(tds, dts): + self._test_not_equal(a, b) + self._test_not_equal(a, [b]) + self._test_not_equal([a], [b]) + self._test_not_equal([a], np.datetime64("2017-01-01", "s")) + self._test_not_equal([b], np.datetime64("2017-01-01", "s")) + self._test_not_equal([a], np.timedelta64(123, "s")) + self._test_not_equal([b], np.timedelta64(123, "s")) + def test_non_numeric(self): self._assert_func('ab', 'ab') self._test_not_equal('ab', 'abb') |