summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2023-05-06 20:50:09 -0400
committerGitHub <noreply@github.com>2023-05-06 20:50:09 -0400
commit276cc995c5e3860e226c072c9264012d2132c87a (patch)
treec4d5a4698526d13fad8dfa50664e240ebfc98c04 /numpy/core/src
parentc942d65ea246d582b009c5270de949cb24018e13 (diff)
parent6c394e3d485f3a522d6e7242a577ee04a9126e0b (diff)
downloadnumpy-276cc995c5e3860e226c072c9264012d2132c87a.tar.gz
Merge pull request #23713 from seberg/uint-int-comparisons
ENH: Make signed/unsigned integer comparisons exact
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/umath/loops.c.src42
-rw-r--r--numpy/core/src/umath/loops.h.src11
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c33
3 files changed, 76 insertions, 10 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 397ebaca2..97a74b425 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -545,6 +545,48 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+
+/*
+ * NOTE: It may be nice to vectorize these, OTOH, these are still faster
+ * than the cast we used to do.
+ */
+
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_ulonglong in1 = *(npy_ulonglong *)ip1;
+ const npy_longlong in2 = *(npy_longlong *)ip2;
+ if (in2 < 0) {
+ *(npy_bool *)op1 = 0 @OP@ in2;
+ }
+ else {
+ *(npy_bool *)op1 = in1 @OP@ (npy_ulonglong)in2;
+ }
+ }
+}
+
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_longlong in1 = *(npy_longlong *)ip1;
+ const npy_ulonglong in2 = *(npy_ulonglong *)ip2;
+ if (in1 < 0) {
+ *(npy_bool *)op1 = in1 @OP@ 0;
+ }
+ else {
+ *(npy_bool *)op1 = (npy_ulonglong)in1 @OP@ in2;
+ }
+ }
+}
+/**end repeat**/
+
+
/*
*****************************************************************************
** DATETIME LOOPS **
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index ab54c1966..cce73aff8 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -211,6 +211,17 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+/**end repeat**/
+
#ifndef NPY_DISABLE_OPTIMIZATION
#include "loops_unary.dispatch.h"
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 12187d059..decd26580 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -381,8 +381,28 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
if (out_dtypes[0] == NULL) {
return -1;
}
- out_dtypes[1] = out_dtypes[0];
- Py_INCREF(out_dtypes[1]);
+ if (PyArray_ISINTEGER(operands[0])
+ && PyArray_ISINTEGER(operands[1])
+ && !PyDataType_ISINTEGER(out_dtypes[0])) {
+ /*
+ * NumPy promotion allows uint+int to go to float, avoid it
+ * (input must have been a mix of signed and unsigned)
+ */
+ if (PyArray_ISSIGNED(operands[0])) {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_LONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_ULONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ else {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_ULONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_LONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ }
+ else {
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ }
}
else {
/* Not doing anything will lead to a loop no found error. */
@@ -398,15 +418,8 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
operands, type_tup, out_dtypes);
}
- /* Output type is always boolean */
+ /* Output type is always boolean (cannot fail for builtins) */
out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
- if (out_dtypes[2] == NULL) {
- for (i = 0; i < 2; ++i) {
- Py_DECREF(out_dtypes[i]);
- out_dtypes[i] = NULL;
- }
- return -1;
- }
/* Check against the casting rules */
if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {