summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/arrayprint.py6
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src2
-rw-r--r--numpy/core/src/umath/loops.c.src26
-rw-r--r--numpy/core/tests/test_datetime.py36
-rw-r--r--numpy/ma/tests/test_extras.py2
-rw-r--r--numpy/ma/testutils.py5
-rw-r--r--numpy/testing/tests/test_utils.py33
-rw-r--r--numpy/testing/utils.py49
8 files changed, 135 insertions, 24 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index fefcb6493..c5b5b5a8f 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -739,8 +739,8 @@ class DatetimeFormat(object):
class TimedeltaFormat(object):
def __init__(self, data):
if data.dtype.kind == 'm':
- nat_value = array(['NaT'], dtype=data.dtype)[0]
- v = data[not_equal(data, nat_value)].view('i8')
+ # select non-NaT elements
+ v = data[data == data].view('i8')
if len(v) > 0:
# Max str length of non-NaT elements
max_str_len = max(len(str(maximum.reduce(v))),
@@ -754,7 +754,7 @@ class TimedeltaFormat(object):
self._nat = "'NaT'".rjust(max_str_len)
def __call__(self, x):
- if x + 1 == x:
+ if x != x:
return self._nat
else:
return self.format % x.astype('i8')
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 1bd5b22d2..7c73822dd 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1673,7 +1673,7 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
* However, as a special case, void-scalar assignment broadcasts
* differently from ndarrays when assigning to an object field: Assignment
* to an ndarray object field broadcasts, but assignment to a void-scalar
- * object-field should not, in order to allow nested ndarrays.
+ * object-field should not, in order to allow nested ndarrays.
* These lines should then behave identically:
*
* b = np.zeros(1, dtype=[('x', 'O')])
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index fc9ffec94..563761bc0 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1117,8 +1117,8 @@ NPY_NO_EXPORT void
}
/**begin repeat1
- * #kind = equal, not_equal, greater, greater_equal, less, less_equal#
- * #OP = ==, !=, >, >=, <, <=#
+ * #kind = equal, greater, greater_equal, less, less_equal#
+ * #OP = ==, >, >=, <, <=#
*/
NPY_NO_EXPORT void
@TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
@@ -1126,11 +1126,31 @@ NPY_NO_EXPORT void
BINARY_LOOP {
const @type@ in1 = *(@type@ *)ip1;
const @type@ in2 = *(@type@ *)ip2;
- *((npy_bool *)op1) = in1 @OP@ in2;
+ if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
+ *((npy_bool *)op1) = NPY_FALSE;
+ }
+ else {
+ *((npy_bool *)op1) = in1 @OP@ in2;
+ }
}
}
/**end repeat1**/
+NPY_NO_EXPORT void
+@TYPE@_not_equal(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const @type@ in1 = *(@type@ *)ip1;
+ const @type@ in2 = *(@type@ *)ip2;
+ if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
+ *((npy_bool *)op1) = NPY_TRUE;
+ }
+ else {
+ *((npy_bool *)op1) = in1 != in2;
+ }
+ }
+}
+
/**begin repeat1
* #kind = maximum, minimum#
* #OP = >, <#
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index 360463d38..65b1d460a 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -130,10 +130,11 @@ class TestDateTime(TestCase):
# regression tests for GH6452
assert_equal(np.datetime64('NaT'),
np.datetime64('2000') + np.timedelta64('NaT'))
- # nb. we may want to make NaT != NaT true in the future; this test
- # verifies the existing behavior (and that it should not warn)
- assert_(np.datetime64('NaT') == np.datetime64('NaT', 'us'))
- assert_(np.datetime64('NaT', 'us') == np.datetime64('NaT'))
+ assert_equal(np.datetime64('NaT'), np.datetime64('NaT', 'us'))
+ assert_equal(np.timedelta64('NaT'), np.timedelta64('NaT', 'us'))
+ # neither of these should issue a warning
+ assert_(np.datetime64('NaT') != np.datetime64('NaT', 'us'))
+ assert_(np.datetime64('NaT', 'us') != np.datetime64('NaT'))
def test_datetime_scalar_construction(self):
# Construct with different units
@@ -552,6 +553,9 @@ class TestDateTime(TestCase):
"'%s'" % np.datetime_as_string(x, timezone='UTC')}),
"['2011-03-16T13:55Z', '1920-01-01T03:12Z']")
+ a = np.array(['NaT', 'NaT'], dtype='datetime64[ns]')
+ assert_equal(str(a), "['NaT' 'NaT']")
+
# Check that one NaT doesn't corrupt subsequent entries
a = np.array(['2010', 'NaT', '2030']).astype('M')
assert_equal(str(a), "['2010' 'NaT' '2030']")
@@ -658,7 +662,7 @@ class TestDateTime(TestCase):
b[8] = 'NaT'
assert_equal(b.astype(object).astype(unit), b,
- "Error roundtripping unit %s" % unit)
+ "Error roundtripping unit %s" % unit)
# With time units
for unit in ['M8[as]', 'M8[16fs]', 'M8[ps]', 'M8[us]',
'M8[300as]', 'M8[20us]']:
@@ -674,7 +678,7 @@ class TestDateTime(TestCase):
b[8] = 'NaT'
assert_equal(b.astype(object).astype(unit), b,
- "Error roundtripping unit %s" % unit)
+ "Error roundtripping unit %s" % unit)
def test_month_truncation(self):
# Make sure that months are truncating correctly
@@ -1081,6 +1085,26 @@ class TestDateTime(TestCase):
assert_equal(np.greater(a, b), [0, 1, 0, 1, 0])
assert_equal(np.greater_equal(a, b), [1, 1, 0, 1, 0])
+ def test_datetime_compare_nat(self):
+ dt_nat = np.datetime64('NaT', 'D')
+ dt_other = np.datetime64('2000-01-01')
+ td_nat = np.timedelta64('NaT', 'h')
+ td_other = np.timedelta64(1, 'h')
+ for op in [np.equal, np.less, np.less_equal,
+ np.greater, np.greater_equal]:
+ assert_(not op(dt_nat, dt_nat))
+ assert_(not op(dt_nat, dt_other))
+ assert_(not op(dt_other, dt_nat))
+ assert_(not op(td_nat, td_nat))
+ assert_(not op(td_nat, td_other))
+ assert_(not op(td_other, td_nat))
+ assert_(np.not_equal(dt_nat, dt_nat))
+ assert_(np.not_equal(dt_nat, dt_other))
+ assert_(np.not_equal(dt_other, dt_nat))
+ assert_(np.not_equal(td_nat, td_nat))
+ assert_(np.not_equal(td_nat, td_other))
+ assert_(np.not_equal(td_other, td_nat))
+
def test_datetime_minmax(self):
# The metadata of the result should become the GCD
# of the operand metadata
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 6138d0573..c2428fa10 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -154,7 +154,7 @@ class TestAverage(TestCase):
ott = ott.reshape(2, 2)
ott[:, 1] = masked
assert_equal(average(ott, axis=0), [2.0, 0.0])
- assert_equal(average(ott, axis=1).mask[0], [True])
+ assert_equal(average(ott, axis=1).mask[0], True)
assert_equal([2., 0.], average(ott, axis=0))
result, wts = average(ott, axis=0, returned=1)
assert_equal(wts, [1., 0.])
diff --git a/numpy/ma/testutils.py b/numpy/ma/testutils.py
index 8dc821878..40b9fa1be 100644
--- a/numpy/ma/testutils.py
+++ b/numpy/ma/testutils.py
@@ -125,10 +125,7 @@ def assert_equal(actual, desired, err_msg=''):
if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
return _assert_equal_on_sequences(actual, desired, err_msg='')
if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
- msg = build_err_msg([actual, desired], err_msg,)
- if not desired == actual:
- raise AssertionError(msg)
- return
+ return utils.assert_equal(actual, desired)
# Case #4. arrays or equivalent
if ((actual is masked) and not (desired is masked)) or \
((desired is masked) and not (actual is masked)):
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 23bd491bc..92a00f712 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -11,7 +11,7 @@ from numpy.testing import (
assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal,
assert_array_almost_equal_nulp, assert_array_max_ulp,
clear_and_catch_warnings, run_module_suite,
- assert_string_equal, assert_, tempdir, temppath,
+ assert_string_equal, assert_, tempdir, temppath,
)
import unittest
@@ -119,6 +119,25 @@ class TestArrayEqual(_GenericTest, unittest.TestCase):
c = np.array([1, 2, 3])
self._test_not_equal(c, b)
+ def test_nat_array_datetime(self):
+ a = np.array([np.datetime64('2000-01'), np.datetime64('NaT')])
+ b = np.array([np.datetime64('2000-01'), np.datetime64('NaT')])
+ self._test_equal(a, b)
+
+ c = np.array([np.datetime64('NaT'), np.datetime64('NaT')])
+ self._test_not_equal(c, b)
+
+ def test_nat_array_timedelta(self):
+ a = np.array([np.timedelta64(1, 'h'), np.timedelta64('NaT')])
+ b = np.array([np.timedelta64(1, 'h'), np.timedelta64('NaT')])
+ self._test_equal(a, b)
+
+ c = np.array([np.timedelta64('NaT'), np.timedelta64('NaT')])
+ self._test_not_equal(c, b)
+
+ d = np.array([np.datetime64('NaT'), np.datetime64('NaT')])
+ self._test_not_equal(c, d)
+
def test_string_arrays(self):
"""Test two arrays with different shapes are found not equal."""
a = np.array(['floupi', 'floupa'])
@@ -227,6 +246,16 @@ class TestEqual(TestArrayEqual):
self._assert_func(x, x)
self._test_not_equal(x, y)
+ def test_nat(self):
+ dt = np.datetime64('2000-01-01')
+ dt_nat = np.datetime64('NaT')
+ td_nat = np.timedelta64('NaT')
+ self._assert_func(dt_nat, dt_nat)
+ self._assert_func(td_nat, td_nat)
+ self._test_not_equal(dt_nat, td_nat)
+ self._test_not_equal(dt, td_nat)
+ self._test_not_equal(dt, dt_nat)
+
class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
@@ -457,7 +486,7 @@ class TestWarns(unittest.TestCase):
class TestAssertAllclose(unittest.TestCase):
-
+
def test_simple(self):
x = 1e-3
y = 1e-9
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index f545cd3c2..8e71a3399 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -15,7 +15,7 @@ import contextlib
from tempfile import mkdtemp, mkstemp
from .nosetester import import_nose
-from numpy.core import float32, empty, arange, array_repr, ndarray
+from numpy.core import float32, empty, arange, array_repr, ndarray, dtype
from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
@@ -343,16 +343,31 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
except AssertionError:
raise AssertionError(msg)
+ def isnat(x):
+ return (hasattr(x, 'dtype')
+ and getattr(x.dtype, 'kind', '_') in 'mM'
+ and x != x)
+
# Inf/nan/negative zero handling
try:
# isscalar test to check cases such as [np.nan] != np.nan
- if isscalar(desired) != isscalar(actual):
+ # dtypes compare equal to strings, but unlike strings aren't scalars,
+ # so we need to exclude them from this check
+ if (isscalar(desired) != isscalar(actual)
+ and not (isinstance(desired, dtype)
+ or isinstance(actual, dtype))):
raise AssertionError(msg)
+ # check NaT before NaN, because isfinite errors on datetime dtypes
+ if isnat(desired) and isnat(actual):
+ if desired.dtype.kind != actual.dtype.kind:
+ # datetime64 and timedelta64 NaT should not be comparable
+ raise AssertionError(msg)
+ return
# If one of desired/actual is not finite, handle it specially here:
# check that both are nan if any is a nan, and test for equality
# otherwise
- if not (gisfinite(desired) and gisfinite(actual)):
+ elif not (gisfinite(desired) and gisfinite(actual)):
isdesnan = gisnan(desired)
isactnan = gisnan(actual)
if isdesnan or isactnan:
@@ -663,6 +678,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
+ def isdatetime(x):
+ return x.dtype.char in 'mM'
+
def chk_same_position(x_id, y_id, hasval='nan'):
"""Handling nan/inf: check that x and y have the nan/inf at the same
locations."""
@@ -675,6 +693,15 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
+ def chk_same_dtype(x_dt, y_dt):
+ try:
+ assert_equal(x_dt, y_dt)
+ except AssertionError:
+ msg = build_err_msg([x, y], err_msg + '\nx and y dtype mismatch',
+ verbose=verbose, header=header,
+ names=('x', 'y'), precision=precision)
+ raise AssertionError(msg)
+
try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if not cond:
@@ -712,6 +739,20 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
val = safe_comparison(x[~x_id], y[~y_id])
else:
val = safe_comparison(x, y)
+ elif isdatetime(x) and isdatetime(y):
+ x_isnat, y_isnat = (x != x), (y != y)
+
+ if any(x_isnat) or any(y_isnat):
+ # cannot mix timedelta64/datetime64 NaT
+ chk_same_dtype(x.dtype, y.dtype)
+ chk_same_position(x_isnat, y_isnat, hasval='nat')
+
+ if all(x_isnat):
+ return
+ if any(x_isnat):
+ val = safe_comparison(x[~x_isnat], y[~y_isnat])
+ else:
+ val = safe_comparison(x, y)
else:
val = safe_comparison(x, y)
@@ -1826,7 +1867,7 @@ def temppath(*args, **kwargs):
parameters are the same as for tempfile.mkstemp and are passed directly
to that function. The underlying file is removed when the context is
exited, so it should be closed at that time.
-
+
Windows does not allow a temporary file to be opened if it is already
open, so the underlying file must be closed after opening before it
can be opened again.