summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-05-04 16:33:27 +0200
committerSebastian Berg <sebastianb@nvidia.com>2023-05-04 16:33:27 +0200
commitec8d5db302c0e8597feb058f58863d5e9a6554c1 (patch)
tree6edf099a4deebdc5ab86fdad314a892d3db7db7b /numpy/core/src
parentc37a577c9df74e29c97a7bb010de0b37f83870bb (diff)
downloadnumpy-ec8d5db302c0e8597feb058f58863d5e9a6554c1.tar.gz
ENH: Make signed/unsigned integer comparisons exact
This makes comparisons between signed and unsigned integers exact by special-casing promotion in comparison to never promote integers to floats, but rather promote them to uint64 or int64 and use a specific loop for that purpose. This is a bit lazy, it doesn't make the scalar paths fast (they never were though) nor does it try to vectorize the loop. Thus, for cases that are not int64/uint64 already and require a cast in either case, it should be a bit slower. OTOH, it was never really fast and the int64/uint64 mix is probably faster since it avoids casting. --- Now... the reason I was looking into this was, that I had hoped it would help with NEP 50/weak scalar typing to allow: uint64(1) < -1 # annoying that it fails with NEP 50 but, it doesn't actually, because if I use int64 for the -1 then very large numbers would be a problem... I could probably(?) add a *specific* "Python integer" ArrayMethod for comparisons and that could pick `object` dtype and thus get the original Python object (the loop could then in practice assume a scalar value). --- In either case, this works, and unless we worry about keeping the behavior we probably might as well do this. (Potentially with follow-ups to speed it up.)
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/umath/loops.c.src42
-rw-r--r--numpy/core/src/umath/loops.h.src11
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c33
3 files changed, 76 insertions, 10 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 397ebaca2..97a74b425 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -545,6 +545,48 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+
+/*
+ * NOTE: It may be nice to vectorize these, OTOH, these are still faster
+ * than the cast we used to do.
+ */
+
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_ulonglong in1 = *(npy_ulonglong *)ip1;
+ const npy_longlong in2 = *(npy_longlong *)ip2;
+ if (in2 < 0) {
+ *(npy_bool *)op1 = 0 @OP@ in2;
+ }
+ else {
+ *(npy_bool *)op1 = in1 @OP@ (npy_ulonglong)in2;
+ }
+ }
+}
+
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_longlong in1 = *(npy_longlong *)ip1;
+ const npy_ulonglong in2 = *(npy_ulonglong *)ip2;
+ if (in1 < 0) {
+ *(npy_bool *)op1 = in1 @OP@ 0;
+ }
+ else {
+ *(npy_bool *)op1 = (npy_ulonglong)in1 @OP@ in2;
+ }
+ }
+}
+/**end repeat**/
+
+
/*
*****************************************************************************
** DATETIME LOOPS **
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index ab54c1966..cce73aff8 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -211,6 +211,17 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+/**end repeat**/
+
#ifndef NPY_DISABLE_OPTIMIZATION
#include "loops_unary.dispatch.h"
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 12187d059..decd26580 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -381,8 +381,28 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
if (out_dtypes[0] == NULL) {
return -1;
}
- out_dtypes[1] = out_dtypes[0];
- Py_INCREF(out_dtypes[1]);
+ if (PyArray_ISINTEGER(operands[0])
+ && PyArray_ISINTEGER(operands[1])
+ && !PyDataType_ISINTEGER(out_dtypes[0])) {
+ /*
+ * NumPy promotion allows uint+int to go to float, avoid it
+ * (input must have been a mix of signed and unsigned)
+ */
+ if (PyArray_ISSIGNED(operands[0])) {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_LONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_ULONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ else {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_ULONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_LONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ }
+ else {
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ }
}
else {
/* Not doing anything will lead to a loop no found error. */
@@ -398,15 +418,8 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
operands, type_tup, out_dtypes);
}
- /* Output type is always boolean */
+ /* Output type is always boolean (cannot fail for builtins) */
out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
- if (out_dtypes[2] == NULL) {
- for (i = 0; i < 2; ++i) {
- Py_DECREF(out_dtypes[i]);
- out_dtypes[i] = NULL;
- }
- return -1;
- }
/* Check against the casting rules */
if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {