author     Sebastian Berg <sebastianb@nvidia.com>    2023-05-04 16:33:27 +0200
committer  Sebastian Berg <sebastianb@nvidia.com>    2023-05-04 16:33:27 +0200
commit     ec8d5db302c0e8597feb058f58863d5e9a6554c1 (patch)
tree       6edf099a4deebdc5ab86fdad314a892d3db7db7b /numpy/core/src
parent     c37a577c9df74e29c97a7bb010de0b37f83870bb (diff)
download   numpy-ec8d5db302c0e8597feb058f58863d5e9a6554c1.tar.gz
ENH: Make signed/unsigned integer comparisons exact
This makes comparisons between signed and unsigned integers exact by
special-casing promotion for comparisons: mixed signed/unsigned integer
inputs are never promoted to float but rather to int64 or uint64, and a
dedicated loop handles that combination.
This is a bit lazy: it doesn't make the scalar paths fast (they never
were), nor does it try to vectorize the loop.
Thus, inputs that are not already int64/uint64 and require a cast either
way should be a bit slower. OTOH, that path was never really fast, and
the int64/uint64 mix is probably faster since it avoids casting.
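(Not part of the commit, just a sketch of the trick the new loops rely
on, in Python for brevity: when the signed operand is negative the result
is already determined, otherwise both values are non-negative and the
comparison is exact without any detour through float.)

    def uint64_less_than_int64(u, s):
        # Mirrors the LONGLONG_Qq_bool_less loop in the diff below
        # (illustrative only): u plays the role of the unsigned operand,
        # s the signed one.
        if s < 0:
            # An unsigned value can never be smaller than a negative one;
            # the C loop expresses this as `0 @OP@ in2`.
            return False
        # Both operands are >= 0, so a plain integer comparison is exact
        # (the C loop casts the signed value to unsigned here).
        return u < s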
---
Now... the reason I was looking into this was that I had hoped it would
help with NEP 50/weak scalar typing to allow:

    uint64(1) < -1  # annoying that it fails with NEP 50

but it doesn't actually help, because if I use int64 for the -1, then
very large Python integers would be a problem...
I could probably(?) add a *specific* "Python integer" ArrayMethod for
comparisons that picks `object` dtype and thus receives the original
Python object (the loop could then in practice assume a scalar value).
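(Plain-Python illustration, not part of the commit: a single signed
64-bit fallback cannot represent every Python int that might be compared
against a uint64 array, which is why mapping the scalar to int64 is not
enough.)

    INT64_MIN, INT64_MAX = -2**63, 2**63 - 1
    UINT64_MAX = 2**64 - 1

    print(INT64_MIN <= -1 <= INT64_MAX)          # True:  -1 fits in int64
    print(INT64_MIN <= UINT64_MAX <= INT64_MAX)  # False: large values do not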
---
In either case, this works, and unless we worry about preserving the old
behavior, we might as well do this.
(Potentially with follow-ups to speed it up.)
Diffstat (limited to 'numpy/core/src')
-rw-r--r--  numpy/core/src/umath/loops.c.src              42
-rw-r--r--  numpy/core/src/umath/loops.h.src              11
-rw-r--r--  numpy/core/src/umath/ufunc_type_resolution.c  33
3 files changed, 76 insertions, 10 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 397ebaca2..97a74b425 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -545,6 +545,48 @@ NPY_NO_EXPORT void
 /**end repeat1**/
 /**end repeat**/
 
+
+/*
+ * NOTE: It may be nice to vectorize these, OTOH, these are still faster
+ * than the cast we used to do.
+ */
+
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+    BINARY_LOOP {
+        const npy_ulonglong in1 = *(npy_ulonglong *)ip1;
+        const npy_longlong in2 = *(npy_longlong *)ip2;
+        if (in2 < 0) {
+            *(npy_bool *)op1 = 0 @OP@ in2;
+        }
+        else {
+            *(npy_bool *)op1 = in1 @OP@ (npy_ulonglong)in2;
+        }
+    }
+}
+
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+    BINARY_LOOP {
+        const npy_longlong in1 = *(npy_longlong *)ip1;
+        const npy_ulonglong in2 = *(npy_ulonglong *)ip2;
+        if (in1 < 0) {
+            *(npy_bool *)op1 = in1 @OP@ 0;
+        }
+        else {
+            *(npy_bool *)op1 = (npy_ulonglong)in1 @OP@ in2;
+        }
+    }
+}
+/**end repeat**/
+
+
 /*
  *****************************************************************************
  **                          DATETIME LOOPS                                 **
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index ab54c1966..cce73aff8 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -211,6 +211,17 @@ NPY_NO_EXPORT void
 /**end repeat1**/
 /**end repeat**/
 
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+/**end repeat**/
+
 #ifndef NPY_DISABLE_OPTIMIZATION
     #include "loops_unary.dispatch.h"
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 12187d059..decd26580 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -381,8 +381,28 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
         if (out_dtypes[0] == NULL) {
             return -1;
         }
-        out_dtypes[1] = out_dtypes[0];
-        Py_INCREF(out_dtypes[1]);
+        if (PyArray_ISINTEGER(operands[0])
+                && PyArray_ISINTEGER(operands[1])
+                && !PyDataType_ISINTEGER(out_dtypes[0])) {
+            /*
+             * NumPy promotion allows uint+int to go to float, avoid it
+             * (input must have been a mix of signed and unsigned)
+             */
+            if (PyArray_ISSIGNED(operands[0])) {
+                Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_LONGLONG));
+                out_dtypes[1] = PyArray_DescrFromType(NPY_ULONGLONG);
+                Py_INCREF(out_dtypes[1]);
+            }
+            else {
+                Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_ULONGLONG));
+                out_dtypes[1] = PyArray_DescrFromType(NPY_LONGLONG);
+                Py_INCREF(out_dtypes[1]);
+            }
+        }
+        else {
+            out_dtypes[1] = out_dtypes[0];
+            Py_INCREF(out_dtypes[1]);
+        }
     }
     else {
         /* Not doing anything will lead to a loop no found error. */
@@ -398,15 +418,8 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
                 operands, type_tup, out_dtypes);
     }
 
-    /* Output type is always boolean */
+    /* Output type is always boolean (cannot fail for builtins) */
     out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
-    if (out_dtypes[2] == NULL) {
-        for (i = 0; i < 2; ++i) {
-            Py_DECREF(out_dtypes[i]);
-            out_dtypes[i] = NULL;
-        }
-        return -1;
-    }
 
     /* Check against the casting rules */
     if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
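(Illustrative only, not part of the commit; exact output depends on the
NumPy version in use. The user-visible effect is that mixed uint64/int64
comparisons no longer go through float64, where nearby huge values
collapse to the same float.)

    import numpy as np

    a = np.array([2**63], dtype=np.uint64)      # 9223372036854775808
    b = np.array([2**63 - 1], dtype=np.int64)   # 9223372036854775807

    # Under float64 promotion both values round to the same float64, so the
    # comparison can report equality; with the dedicated int64/uint64 loops
    # the result is exact.
    print(a == b)   # expected with this change: [False]
    print(a > b)    # expected with this change: [ True]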