diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2016-01-14 16:32:57 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2016-01-14 16:32:57 -0700 |
commit | 7141f40b58ed1e7071cde78ab7bc8ab37e9c5983 (patch) | |
tree | 856886f3b9d65fb6305f21f9d692cb0861b861cb /numpy/core | |
parent | 8fa6e3bef26a6d4a2c92f2824129aa4409be2590 (diff) | |
parent | 53ad26a84ac2aa6f5a37f09aa9feae5afed44f79 (diff) | |
download | numpy-7141f40b58ed1e7071cde78ab7bc8ab37e9c5983.tar.gz |
Merge pull request #7001 from shoyer/NaT-comparison
API: make all comparisons with NaT false
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/arrayprint.py | 6 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 2 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 26 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 36 |
4 files changed, 57 insertions, 13 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index fefcb6493..c5b5b5a8f 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -739,8 +739,8 @@ class DatetimeFormat(object): class TimedeltaFormat(object): def __init__(self, data): if data.dtype.kind == 'm': - nat_value = array(['NaT'], dtype=data.dtype)[0] - v = data[not_equal(data, nat_value)].view('i8') + # select non-NaT elements + v = data[data == data].view('i8') if len(v) > 0: # Max str length of non-NaT elements max_str_len = max(len(str(maximum.reduce(v))), @@ -754,7 +754,7 @@ class TimedeltaFormat(object): self._nat = "'NaT'".rjust(max_str_len) def __call__(self, x): - if x + 1 == x: + if x != x: return self._nat else: return self.format % x.astype('i8') diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 1bd5b22d2..7c73822dd 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1673,7 +1673,7 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) * However, as a special case, void-scalar assignment broadcasts * differently from ndarrays when assigning to an object field: Assignment * to an ndarray object field broadcasts, but assignment to a void-scalar - * object-field should not, in order to allow nested ndarrays. + * object-field should not, in order to allow nested ndarrays. * These lines should then behave identically: * * b = np.zeros(1, dtype=[('x', 'O')]) diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index fc9ffec94..563761bc0 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1117,8 +1117,8 @@ NPY_NO_EXPORT void } /**begin repeat1 - * #kind = equal, not_equal, greater, greater_equal, less, less_equal# - * #OP = ==, !=, >, >=, <, <=# + * #kind = equal, greater, greater_equal, less, less_equal# + * #OP = ==, >, >=, <, <=# */ NPY_NO_EXPORT void @TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) @@ -1126,11 +1126,31 @@ NPY_NO_EXPORT void BINARY_LOOP { const @type@ in1 = *(@type@ *)ip1; const @type@ in2 = *(@type@ *)ip2; - *((npy_bool *)op1) = in1 @OP@ in2; + if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) { + *((npy_bool *)op1) = NPY_FALSE; + } + else { + *((npy_bool *)op1) = in1 @OP@ in2; + } } } /**end repeat1**/ +NPY_NO_EXPORT void +@TYPE@_not_equal(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) +{ + BINARY_LOOP { + const @type@ in1 = *(@type@ *)ip1; + const @type@ in2 = *(@type@ *)ip2; + if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) { + *((npy_bool *)op1) = NPY_TRUE; + } + else { + *((npy_bool *)op1) = in1 != in2; + } + } +} + /**begin repeat1 * #kind = maximum, minimum# * #OP = >, <# diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 360463d38..65b1d460a 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -130,10 +130,11 @@ class TestDateTime(TestCase): # regression tests for GH6452 assert_equal(np.datetime64('NaT'), np.datetime64('2000') + np.timedelta64('NaT')) - # nb. we may want to make NaT != NaT true in the future; this test - # verifies the existing behavior (and that it should not warn) - assert_(np.datetime64('NaT') == np.datetime64('NaT', 'us')) - assert_(np.datetime64('NaT', 'us') == np.datetime64('NaT')) + assert_equal(np.datetime64('NaT'), np.datetime64('NaT', 'us')) + assert_equal(np.timedelta64('NaT'), np.timedelta64('NaT', 'us')) + # neither of these should issue a warning + assert_(np.datetime64('NaT') != np.datetime64('NaT', 'us')) + assert_(np.datetime64('NaT', 'us') != np.datetime64('NaT')) def test_datetime_scalar_construction(self): # Construct with different units @@ -552,6 +553,9 @@ class TestDateTime(TestCase): "'%s'" % np.datetime_as_string(x, timezone='UTC')}), "['2011-03-16T13:55Z', '1920-01-01T03:12Z']") + a = np.array(['NaT', 'NaT'], dtype='datetime64[ns]') + assert_equal(str(a), "['NaT' 'NaT']") + # Check that one NaT doesn't corrupt subsequent entries a = np.array(['2010', 'NaT', '2030']).astype('M') assert_equal(str(a), "['2010' 'NaT' '2030']") @@ -658,7 +662,7 @@ class TestDateTime(TestCase): b[8] = 'NaT' assert_equal(b.astype(object).astype(unit), b, - "Error roundtripping unit %s" % unit) + "Error roundtripping unit %s" % unit) # With time units for unit in ['M8[as]', 'M8[16fs]', 'M8[ps]', 'M8[us]', 'M8[300as]', 'M8[20us]']: @@ -674,7 +678,7 @@ class TestDateTime(TestCase): b[8] = 'NaT' assert_equal(b.astype(object).astype(unit), b, - "Error roundtripping unit %s" % unit) + "Error roundtripping unit %s" % unit) def test_month_truncation(self): # Make sure that months are truncating correctly @@ -1081,6 +1085,26 @@ class TestDateTime(TestCase): assert_equal(np.greater(a, b), [0, 1, 0, 1, 0]) assert_equal(np.greater_equal(a, b), [1, 1, 0, 1, 0]) + def test_datetime_compare_nat(self): + dt_nat = np.datetime64('NaT', 'D') + dt_other = np.datetime64('2000-01-01') + td_nat = np.timedelta64('NaT', 'h') + td_other = np.timedelta64(1, 'h') + for op in [np.equal, np.less, np.less_equal, + np.greater, np.greater_equal]: + assert_(not op(dt_nat, dt_nat)) + assert_(not op(dt_nat, dt_other)) + assert_(not op(dt_other, dt_nat)) + assert_(not op(td_nat, td_nat)) + assert_(not op(td_nat, td_other)) + assert_(not op(td_other, td_nat)) + assert_(np.not_equal(dt_nat, dt_nat)) + assert_(np.not_equal(dt_nat, dt_other)) + assert_(np.not_equal(dt_other, dt_nat)) + assert_(np.not_equal(td_nat, td_nat)) + assert_(np.not_equal(td_nat, td_other)) + assert_(np.not_equal(td_other, td_nat)) + def test_datetime_minmax(self): # The metadata of the result should become the GCD # of the operand metadata |