summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src53
-rw-r--r--numpy/core/src/umath/loops.c.src29
-rw-r--r--numpy/core/tests/test_multiarray.py44
3 files changed, 96 insertions, 30 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index bffcc26a6..ce7b61287 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -2987,18 +2987,16 @@ BOOL_argmin(npy_bool *ip, npy_intp n, npy_intp *min_ind,
* #fname = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG,
* HALF, FLOAT, DOUBLE, LONGDOUBLE,
- * CFLOAT, CDOUBLE, CLONGDOUBLE,
- * DATETIME, TIMEDELTA#
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_half, npy_float, npy_double, npy_longdouble,
- * npy_float, npy_double, npy_longdouble,
- * npy_datetime, npy_timedelta#
- * #isfloat = 0*10, 1*7, 0*2#
- * #isnan = nop*10, npy_half_isnan, npy_isnan*6, nop*2#
- * #le = _LESS_THAN_OR_EQUAL*10, npy_half_le, _LESS_THAN_OR_EQUAL*8#
- * #iscomplex = 0*14, 1*3, 0*2#
- * #incr = ip++*14, ip+=2*3, ip++*2#
+ * npy_float, npy_double, npy_longdouble#
+ * #isfloat = 0*10, 1*7#
+ * #isnan = nop*10, npy_half_isnan, npy_isnan*6#
+ * #le = _LESS_THAN_OR_EQUAL*10, npy_half_le, _LESS_THAN_OR_EQUAL*6#
+ * #iscomplex = 0*14, 1*3#
+ * #incr = ip++*14, ip+=2*3#
*/
static int
@fname@_argmin(@type@ *ip, npy_intp n, npy_intp *min_ind,
@@ -3062,6 +3060,43 @@ static int
#undef _LESS_THAN_OR_EQUAL
+/**begin repeat
+ *
+ * #fname = DATETIME, TIMEDELTA#
+ * #type = npy_datetime, npy_timedelta#
+ */
+static int
+@fname@_argmin(@type@ *ip, npy_intp n, npy_intp *min_ind,
+ PyArrayObject *NPY_UNUSED(aip))
+{
+ /* NPY_DATETIME_NAT is smaller than every other value, we skip
+ * it for consistency with min().
+ */
+ npy_intp i;
+ @type@ mp = NPY_DATETIME_NAT;
+
+ i = 0;
+ while (i < n && mp == NPY_DATETIME_NAT) {
+ mp = ip[i];
+ i++;
+ }
+ if (i == n) {
+ /* All NaTs: return 0 */
+ *min_ind = 0;
+ return 0;
+ }
+ *min_ind = i - 1;
+ for (; i < n; i++) {
+ if (mp > ip[i] && ip[i] != NPY_DATETIME_NAT) {
+ mp = ip[i];
+ *min_ind = i;
+ }
+ }
+ return 0;
+}
+
+/**end repeat**/
+
static int
OBJECT_argmax(PyObject **ip, npy_intp n, npy_intp *max_ind,
PyArrayObject *NPY_UNUSED(aip))
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index d0fd0e43b..21e36ee2f 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1141,26 +1141,17 @@ NPY_NO_EXPORT void
NPY_NO_EXPORT void
@TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
{
- if (IS_BINARY_REDUCE) {
- BINARY_REDUCE_LOOP(@type@) {
- const @type@ in2 = *(@type@ *)ip2;
- io1 = (io1 @OP@ in2 || in2 == NPY_DATETIME_NAT) ? io1 : in2;
+ BINARY_LOOP {
+ const @type@ in1 = *(@type@ *)ip1;
+ const @type@ in2 = *(@type@ *)ip2;
+ if (in1 == NPY_DATETIME_NAT) {
+ *((@type@ *)op1) = in2;
}
- *((@type@ *)iop1) = io1;
- }
- else {
- BINARY_LOOP {
- const @type@ in1 = *(@type@ *)ip1;
- const @type@ in2 = *(@type@ *)ip2;
- if (in1 == NPY_DATETIME_NAT) {
- *((@type@ *)op1) = in2;
- }
- else if (in2 == NPY_DATETIME_NAT) {
- *((@type@ *)op1) = in1;
- }
- else {
- *((@type@ *)op1) = (in1 @OP@ in2) ? in1 : in2;
- }
+ else if (in2 == NPY_DATETIME_NAT) {
+ *((@type@ *)op1) = in1;
+ }
+ else {
+ *((@type@ *)op1) = (in1 @OP@ in2) ? in1 : in2;
}
}
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 9822d7dfc..34045b4a4 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2681,12 +2681,24 @@ class TestArgmax(TestCase):
np.datetime64('2015-11-20T12:20:59'),
np.datetime64('1932-09-23T10:10:13'),
np.datetime64('2014-10-10T03:50:30')], 3),
+ # Assorted tests with NaTs
+ ([np.datetime64('NaT'),
+ np.datetime64('NaT'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('NaT'),
+ np.datetime64('2015-09-23T10:10:13'),
+ np.datetime64('1932-10-10T03:50:30')], 4),
([np.datetime64('2059-03-14T12:43:12'),
np.datetime64('1996-09-21T14:43:15'),
- np.datetime64('2001-10-15T04:10:32'),
+ np.datetime64('NaT'),
np.datetime64('2022-12-25T16:02:16'),
np.datetime64('1963-10-04T03:14:12'),
np.datetime64('2013-05-08T18:15:23')], 0),
+ ([np.timedelta64(2, 's'),
+ np.timedelta64(1, 's'),
+ np.timedelta64('NaT', 's'),
+ np.timedelta64(3, 's')], 3),
+ ([np.timedelta64('NaT', 's')] * 3, 0),
([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
timedelta(days=-1, seconds=23)], 0),
@@ -2793,12 +2805,24 @@ class TestArgmin(TestCase):
np.datetime64('2014-11-20T12:20:59'),
np.datetime64('2015-09-23T10:10:13'),
np.datetime64('1932-10-10T03:50:30')], 5),
+ # Assorted tests with NaTs
+ ([np.datetime64('NaT'),
+ np.datetime64('NaT'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('NaT'),
+ np.datetime64('2015-09-23T10:10:13'),
+ np.datetime64('1932-10-10T03:50:30')], 5),
([np.datetime64('2059-03-14T12:43:12'),
np.datetime64('1996-09-21T14:43:15'),
- np.datetime64('2001-10-15T04:10:32'),
+ np.datetime64('NaT'),
np.datetime64('2022-12-25T16:02:16'),
np.datetime64('1963-10-04T03:14:12'),
np.datetime64('2013-05-08T18:15:23')], 4),
+ ([np.timedelta64(2, 's'),
+ np.timedelta64(1, 's'),
+ np.timedelta64('NaT', 's'),
+ np.timedelta64(3, 's')], 1),
+ ([np.timedelta64('NaT', 's')] * 3, 0),
([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
timedelta(days=-1, seconds=23)], 2),
@@ -2887,6 +2911,7 @@ class TestArgmin(TestCase):
class TestMinMax(TestCase):
+
def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)
assert_raises(ValueError, np.amin, 1, 1)
@@ -2900,6 +2925,21 @@ class TestMinMax(TestCase):
assert_raises(ValueError, np.amax, [1, 2, 3], 1000)
assert_equal(np.amax([[1, 2, 3]], axis=1), 3)
+ def test_datetime(self):
+ # NaTs are ignored
+ for dtype in ('m8[s]', 'm8[Y]'):
+ a = np.arange(10).astype(dtype)
+ a[3] = 'NaT'
+ assert_equal(np.amin(a), a[0])
+ assert_equal(np.amax(a), a[9])
+ a[0] = 'NaT'
+ assert_equal(np.amin(a), a[1])
+ assert_equal(np.amax(a), a[9])
+ a.fill('NaT')
+ assert_equal(np.amin(a), a[0])
+ assert_equal(np.amax(a), a[0])
+
+
class TestNewaxis(TestCase):
def test_basic(self):
sk = array([0, -0.1, 0.1])