diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 53 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 29 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 44 |
3 files changed, 96 insertions, 30 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index bffcc26a6..ce7b61287 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2987,18 +2987,16 @@ BOOL_argmin(npy_bool *ip, npy_intp n, npy_intp *min_ind, * #fname = BYTE, UBYTE, SHORT, USHORT, INT, UINT, * LONG, ULONG, LONGLONG, ULONGLONG, * HALF, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE, - * DATETIME, TIMEDELTA# + * CFLOAT, CDOUBLE, CLONGDOUBLE# * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint, * npy_long, npy_ulong, npy_longlong, npy_ulonglong, * npy_half, npy_float, npy_double, npy_longdouble, - * npy_float, npy_double, npy_longdouble, - * npy_datetime, npy_timedelta# - * #isfloat = 0*10, 1*7, 0*2# - * #isnan = nop*10, npy_half_isnan, npy_isnan*6, nop*2# - * #le = _LESS_THAN_OR_EQUAL*10, npy_half_le, _LESS_THAN_OR_EQUAL*8# - * #iscomplex = 0*14, 1*3, 0*2# - * #incr = ip++*14, ip+=2*3, ip++*2# + * npy_float, npy_double, npy_longdouble# + * #isfloat = 0*10, 1*7# + * #isnan = nop*10, npy_half_isnan, npy_isnan*6# + * #le = _LESS_THAN_OR_EQUAL*10, npy_half_le, _LESS_THAN_OR_EQUAL*6# + * #iscomplex = 0*14, 1*3# + * #incr = ip++*14, ip+=2*3# */ static int @fname@_argmin(@type@ *ip, npy_intp n, npy_intp *min_ind, @@ -3062,6 +3060,43 @@ static int #undef _LESS_THAN_OR_EQUAL +/**begin repeat + * + * #fname = DATETIME, TIMEDELTA# + * #type = npy_datetime, npy_timedelta# + */ +static int +@fname@_argmin(@type@ *ip, npy_intp n, npy_intp *min_ind, + PyArrayObject *NPY_UNUSED(aip)) +{ + /* NPY_DATETIME_NAT is smaller than every other value, we skip + * it for consistency with min(). + */ + npy_intp i; + @type@ mp = NPY_DATETIME_NAT; + + i = 0; + while (i < n && mp == NPY_DATETIME_NAT) { + mp = ip[i]; + i++; + } + if (i == n) { + /* All NaTs: return 0 */ + *min_ind = 0; + return 0; + } + *min_ind = i - 1; + for (; i < n; i++) { + if (mp > ip[i] && ip[i] != NPY_DATETIME_NAT) { + mp = ip[i]; + *min_ind = i; + } + } + return 0; +} + +/**end repeat**/ + static int OBJECT_argmax(PyObject **ip, npy_intp n, npy_intp *max_ind, PyArrayObject *NPY_UNUSED(aip)) diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index d0fd0e43b..21e36ee2f 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1141,26 +1141,17 @@ NPY_NO_EXPORT void NPY_NO_EXPORT void @TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) { - if (IS_BINARY_REDUCE) { - BINARY_REDUCE_LOOP(@type@) { - const @type@ in2 = *(@type@ *)ip2; - io1 = (io1 @OP@ in2 || in2 == NPY_DATETIME_NAT) ? io1 : in2; + BINARY_LOOP { + const @type@ in1 = *(@type@ *)ip1; + const @type@ in2 = *(@type@ *)ip2; + if (in1 == NPY_DATETIME_NAT) { + *((@type@ *)op1) = in2; } - *((@type@ *)iop1) = io1; - } - else { - BINARY_LOOP { - const @type@ in1 = *(@type@ *)ip1; - const @type@ in2 = *(@type@ *)ip2; - if (in1 == NPY_DATETIME_NAT) { - *((@type@ *)op1) = in2; - } - else if (in2 == NPY_DATETIME_NAT) { - *((@type@ *)op1) = in1; - } - else { - *((@type@ *)op1) = (in1 @OP@ in2) ? in1 : in2; - } + else if (in2 == NPY_DATETIME_NAT) { + *((@type@ *)op1) = in1; + } + else { + *((@type@ *)op1) = (in1 @OP@ in2) ? in1 : in2; } } } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 9822d7dfc..34045b4a4 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2681,12 +2681,24 @@ class TestArgmax(TestCase): np.datetime64('2015-11-20T12:20:59'), np.datetime64('1932-09-23T10:10:13'), np.datetime64('2014-10-10T03:50:30')], 3), + # Assorted tests with NaTs + ([np.datetime64('NaT'), + np.datetime64('NaT'), + np.datetime64('2010-01-03T05:14:12'), + np.datetime64('NaT'), + np.datetime64('2015-09-23T10:10:13'), + np.datetime64('1932-10-10T03:50:30')], 4), ([np.datetime64('2059-03-14T12:43:12'), np.datetime64('1996-09-21T14:43:15'), - np.datetime64('2001-10-15T04:10:32'), + np.datetime64('NaT'), np.datetime64('2022-12-25T16:02:16'), np.datetime64('1963-10-04T03:14:12'), np.datetime64('2013-05-08T18:15:23')], 0), + ([np.timedelta64(2, 's'), + np.timedelta64(1, 's'), + np.timedelta64('NaT', 's'), + np.timedelta64(3, 's')], 3), + ([np.timedelta64('NaT', 's')] * 3, 0), ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35), timedelta(days=-1, seconds=23)], 0), @@ -2793,12 +2805,24 @@ class TestArgmin(TestCase): np.datetime64('2014-11-20T12:20:59'), np.datetime64('2015-09-23T10:10:13'), np.datetime64('1932-10-10T03:50:30')], 5), + # Assorted tests with NaTs + ([np.datetime64('NaT'), + np.datetime64('NaT'), + np.datetime64('2010-01-03T05:14:12'), + np.datetime64('NaT'), + np.datetime64('2015-09-23T10:10:13'), + np.datetime64('1932-10-10T03:50:30')], 5), ([np.datetime64('2059-03-14T12:43:12'), np.datetime64('1996-09-21T14:43:15'), - np.datetime64('2001-10-15T04:10:32'), + np.datetime64('NaT'), np.datetime64('2022-12-25T16:02:16'), np.datetime64('1963-10-04T03:14:12'), np.datetime64('2013-05-08T18:15:23')], 4), + ([np.timedelta64(2, 's'), + np.timedelta64(1, 's'), + np.timedelta64('NaT', 's'), + np.timedelta64(3, 's')], 1), + ([np.timedelta64('NaT', 's')] * 3, 0), ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35), timedelta(days=-1, seconds=23)], 2), @@ -2887,6 +2911,7 @@ class TestArgmin(TestCase): class TestMinMax(TestCase): + def test_scalar(self): assert_raises(ValueError, np.amax, 1, 1) assert_raises(ValueError, np.amin, 1, 1) @@ -2900,6 +2925,21 @@ class TestMinMax(TestCase): assert_raises(ValueError, np.amax, [1, 2, 3], 1000) assert_equal(np.amax([[1, 2, 3]], axis=1), 3) + def test_datetime(self): + # NaTs are ignored + for dtype in ('m8[s]', 'm8[Y]'): + a = np.arange(10).astype(dtype) + a[3] = 'NaT' + assert_equal(np.amin(a), a[0]) + assert_equal(np.amax(a), a[9]) + a[0] = 'NaT' + assert_equal(np.amin(a), a[1]) + assert_equal(np.amax(a), a[9]) + a.fill('NaT') + assert_equal(np.amin(a), a[0]) + assert_equal(np.amax(a), a[0]) + + class TestNewaxis(TestCase): def test_basic(self): sk = array([0, -0.1, 0.1]) |