summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-01-14 16:32:57 -0700
committerCharles Harris <charlesr.harris@gmail.com>2016-01-14 16:32:57 -0700
commit7141f40b58ed1e7071cde78ab7bc8ab37e9c5983 (patch)
tree856886f3b9d65fb6305f21f9d692cb0861b861cb /numpy/core
parent8fa6e3bef26a6d4a2c92f2824129aa4409be2590 (diff)
parent53ad26a84ac2aa6f5a37f09aa9feae5afed44f79 (diff)
downloadnumpy-7141f40b58ed1e7071cde78ab7bc8ab37e9c5983.tar.gz
Merge pull request #7001 from shoyer/NaT-comparison
API: make all comparisons with NaT false
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/arrayprint.py6
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src2
-rw-r--r--numpy/core/src/umath/loops.c.src26
-rw-r--r--numpy/core/tests/test_datetime.py36
4 files changed, 57 insertions, 13 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index fefcb6493..c5b5b5a8f 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -739,8 +739,8 @@ class DatetimeFormat(object):
class TimedeltaFormat(object):
def __init__(self, data):
if data.dtype.kind == 'm':
- nat_value = array(['NaT'], dtype=data.dtype)[0]
- v = data[not_equal(data, nat_value)].view('i8')
+ # select non-NaT elements
+ v = data[data == data].view('i8')
if len(v) > 0:
# Max str length of non-NaT elements
max_str_len = max(len(str(maximum.reduce(v))),
@@ -754,7 +754,7 @@ class TimedeltaFormat(object):
self._nat = "'NaT'".rjust(max_str_len)
def __call__(self, x):
- if x + 1 == x:
+ if x != x:
return self._nat
else:
return self.format % x.astype('i8')
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 1bd5b22d2..7c73822dd 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1673,7 +1673,7 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
* However, as a special case, void-scalar assignment broadcasts
* differently from ndarrays when assigning to an object field: Assignment
* to an ndarray object field broadcasts, but assignment to a void-scalar
- * object-field should not, in order to allow nested ndarrays.
+ * object-field should not, in order to allow nested ndarrays.
* These lines should then behave identically:
*
* b = np.zeros(1, dtype=[('x', 'O')])
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index fc9ffec94..563761bc0 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1117,8 +1117,8 @@ NPY_NO_EXPORT void
}
/**begin repeat1
- * #kind = equal, not_equal, greater, greater_equal, less, less_equal#
- * #OP = ==, !=, >, >=, <, <=#
+ * #kind = equal, greater, greater_equal, less, less_equal#
+ * #OP = ==, >, >=, <, <=#
*/
NPY_NO_EXPORT void
@TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
@@ -1126,11 +1126,31 @@ NPY_NO_EXPORT void
BINARY_LOOP {
const @type@ in1 = *(@type@ *)ip1;
const @type@ in2 = *(@type@ *)ip2;
- *((npy_bool *)op1) = in1 @OP@ in2;
+ if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
+ *((npy_bool *)op1) = NPY_FALSE;
+ }
+ else {
+ *((npy_bool *)op1) = in1 @OP@ in2;
+ }
}
}
/**end repeat1**/
+NPY_NO_EXPORT void
+@TYPE@_not_equal(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const @type@ in1 = *(@type@ *)ip1;
+ const @type@ in2 = *(@type@ *)ip2;
+ if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
+ *((npy_bool *)op1) = NPY_TRUE;
+ }
+ else {
+ *((npy_bool *)op1) = in1 != in2;
+ }
+ }
+}
+
/**begin repeat1
* #kind = maximum, minimum#
* #OP = >, <#
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 360463d38..65b1d460a 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -130,10 +130,11 @@ class TestDateTime(TestCase):
# regression tests for GH6452
assert_equal(np.datetime64('NaT'),
np.datetime64('2000') + np.timedelta64('NaT'))
- # nb. we may want to make NaT != NaT true in the future; this test
- # verifies the existing behavior (and that it should not warn)
- assert_(np.datetime64('NaT') == np.datetime64('NaT', 'us'))
- assert_(np.datetime64('NaT', 'us') == np.datetime64('NaT'))
+ assert_equal(np.datetime64('NaT'), np.datetime64('NaT', 'us'))
+ assert_equal(np.timedelta64('NaT'), np.timedelta64('NaT', 'us'))
+ # neither of these should issue a warning
+ assert_(np.datetime64('NaT') != np.datetime64('NaT', 'us'))
+ assert_(np.datetime64('NaT', 'us') != np.datetime64('NaT'))
def test_datetime_scalar_construction(self):
# Construct with different units
@@ -552,6 +553,9 @@ class TestDateTime(TestCase):
"'%s'" % np.datetime_as_string(x, timezone='UTC')}),
"['2011-03-16T13:55Z', '1920-01-01T03:12Z']")
+ a = np.array(['NaT', 'NaT'], dtype='datetime64[ns]')
+ assert_equal(str(a), "['NaT' 'NaT']")
+
# Check that one NaT doesn't corrupt subsequent entries
a = np.array(['2010', 'NaT', '2030']).astype('M')
assert_equal(str(a), "['2010' 'NaT' '2030']")
@@ -658,7 +662,7 @@ class TestDateTime(TestCase):
b[8] = 'NaT'
assert_equal(b.astype(object).astype(unit), b,
- "Error roundtripping unit %s" % unit)
+ "Error roundtripping unit %s" % unit)
# With time units
for unit in ['M8[as]', 'M8[16fs]', 'M8[ps]', 'M8[us]',
'M8[300as]', 'M8[20us]']:
@@ -674,7 +678,7 @@ class TestDateTime(TestCase):
b[8] = 'NaT'
assert_equal(b.astype(object).astype(unit), b,
- "Error roundtripping unit %s" % unit)
+ "Error roundtripping unit %s" % unit)
def test_month_truncation(self):
# Make sure that months are truncating correctly
@@ -1081,6 +1085,26 @@ class TestDateTime(TestCase):
assert_equal(np.greater(a, b), [0, 1, 0, 1, 0])
assert_equal(np.greater_equal(a, b), [1, 1, 0, 1, 0])
+ def test_datetime_compare_nat(self):
+ dt_nat = np.datetime64('NaT', 'D')
+ dt_other = np.datetime64('2000-01-01')
+ td_nat = np.timedelta64('NaT', 'h')
+ td_other = np.timedelta64(1, 'h')
+ for op in [np.equal, np.less, np.less_equal,
+ np.greater, np.greater_equal]:
+ assert_(not op(dt_nat, dt_nat))
+ assert_(not op(dt_nat, dt_other))
+ assert_(not op(dt_other, dt_nat))
+ assert_(not op(td_nat, td_nat))
+ assert_(not op(td_nat, td_other))
+ assert_(not op(td_other, td_nat))
+ assert_(np.not_equal(dt_nat, dt_nat))
+ assert_(np.not_equal(dt_nat, dt_other))
+ assert_(np.not_equal(dt_other, dt_nat))
+ assert_(np.not_equal(td_nat, td_nat))
+ assert_(np.not_equal(td_nat, td_other))
+ assert_(np.not_equal(td_other, td_nat))
+
def test_datetime_minmax(self):
# The metadata of the result should become the GCD
# of the operand metadata