summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-05-04 16:33:27 +0200
committerSebastian Berg <sebastianb@nvidia.com>2023-05-04 16:33:27 +0200
commitec8d5db302c0e8597feb058f58863d5e9a6554c1 (patch)
tree6edf099a4deebdc5ab86fdad314a892d3db7db7b
parentc37a577c9df74e29c97a7bb010de0b37f83870bb (diff)
downloadnumpy-ec8d5db302c0e8597feb058f58863d5e9a6554c1.tar.gz
ENH: Make signed/unsigned integer comparisons exact
This makes comparisons between signed and unsigned integers exact by special-casing promotion in comparison to never promote integers to floats, but rather promote them to uint64 or int64 and use a specific loop for that purpose. This is a bit lazy, it doesn't make the scalar paths fast (they never were though) nor does it try to vectorize the loop. Thus, for cases that are not int64/uint64 already and require a cast in either case, it should be a bit slower. OTOH, it was never really fast and the int64/uint64 mix is probably faster since it avoids casting. --- Now... the reason I was looking into this was, that I had hoped it would help with NEP 50/weak scalar typing to allow: uint64(1) < -1 # annoying that it fails with NEP 50 but, it doesn't actually, because if I use int64 for the -1 then very large numbers would be a problem... I could probably(?) add a *specific* "Python integer" ArrayMethod for comparisons and that could pick `object` dtype and thus get the original Python object (the loop could then in practice assume a scalar value). --- In either case, this works, and unless we worry about keeping the behavior we probably might as well do this. (Potentially with follow-ups to speed it up.)
-rw-r--r--numpy/core/code_generators/generate_umath.py52
-rw-r--r--numpy/core/src/umath/loops.c.src42
-rw-r--r--numpy/core/src/umath/loops.h.src11
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c33
-rw-r--r--numpy/core/tests/test_umath.py58
5 files changed, 173 insertions, 23 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index d0306bb72..a170f83be 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -109,6 +109,11 @@ def _check_order(types1, types2):
if t2i > t1i:
break
+ if types1 == "QQ?" and types2 == "qQ?":
+ # Explicitly allow this mixed case, rather than figure out what order
+ # is nicer or how to encode it.
+ return
+
raise TypeError(
f"Input dtypes are unsorted or duplicate: {types1} and {types2}")
@@ -523,49 +528,67 @@ defdict = {
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'greater_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'less':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'less_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'not_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.not_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
- TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]),
- [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
+ TD(bints, out='?'),
+ [TypeDescription('q', FullTypeDescr, 'qQ', '?'),
+ TypeDescription('q', FullTypeDescr, 'Qq', '?')],
+ TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]),
TD('O', out='?'),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'logical_and':
Ufunc(2, 1, True_,
@@ -1172,7 +1195,10 @@ def make_arrays(funcdict):
if t.func_data is FullTypeDescr:
tname = english_upper(chartoname[t.type])
datalist.append('(void *)NULL')
- cfunc_fname = f"{tname}_{t.in_}_{t.out}_{cfunc_alias}"
+ if t.out == "?":
+ cfunc_fname = f"{tname}_{t.in_}_bool_{cfunc_alias}"
+ else:
+ cfunc_fname = f"{tname}_{t.in_}_{t.out}_{cfunc_alias}"
elif isinstance(t.func_data, FuncNameSuffix):
datalist.append('(void *)NULL')
tname = english_upper(chartoname[t.type])
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 397ebaca2..97a74b425 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -545,6 +545,48 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+
+/*
+ * NOTE: It may be nice to vectorize these, OTOH, these are still faster
+ * than the cast we used to do.
+ */
+
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_ulonglong in1 = *(npy_ulonglong *)ip1;
+ const npy_longlong in2 = *(npy_longlong *)ip2;
+ if (in2 < 0) {
+ *(npy_bool *)op1 = 0 @OP@ in2;
+ }
+ else {
+ *(npy_bool *)op1 = in1 @OP@ (npy_ulonglong)in2;
+ }
+ }
+}
+
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP {
+ const npy_longlong in1 = *(npy_longlong *)ip1;
+ const npy_ulonglong in2 = *(npy_ulonglong *)ip2;
+ if (in1 < 0) {
+ *(npy_bool *)op1 = in1 @OP@ 0;
+ }
+ else {
+ *(npy_bool *)op1 = (npy_ulonglong)in1 @OP@ in2;
+ }
+ }
+}
+/**end repeat**/
+
+
/*
*****************************************************************************
** DATETIME LOOPS **
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index ab54c1966..cce73aff8 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -211,6 +211,17 @@ NPY_NO_EXPORT void
/**end repeat1**/
/**end repeat**/
+/**begin repeat
+ * #kind = equal, not_equal, less, less_equal, greater, greater_equal#
+ * #OP = ==, !=, <, <=, >, >=#
+ */
+NPY_NO_EXPORT void
+LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+NPY_NO_EXPORT void
+LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+/**end repeat**/
+
#ifndef NPY_DISABLE_OPTIMIZATION
#include "loops_unary.dispatch.h"
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 12187d059..decd26580 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -381,8 +381,28 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
if (out_dtypes[0] == NULL) {
return -1;
}
- out_dtypes[1] = out_dtypes[0];
- Py_INCREF(out_dtypes[1]);
+ if (PyArray_ISINTEGER(operands[0])
+ && PyArray_ISINTEGER(operands[1])
+ && !PyDataType_ISINTEGER(out_dtypes[0])) {
+ /*
+ * NumPy promotion allows uint+int to go to float, avoid it
+ * (input must have been a mix of signed and unsigned)
+ */
+ if (PyArray_ISSIGNED(operands[0])) {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_LONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_ULONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ else {
+ Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_ULONGLONG));
+ out_dtypes[1] = PyArray_DescrFromType(NPY_LONGLONG);
+ Py_INCREF(out_dtypes[1]);
+ }
+ }
+ else {
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ }
}
else {
/* Not doing anything will lead to a loop no found error. */
@@ -398,15 +418,8 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
operands, type_tup, out_dtypes);
}
- /* Output type is always boolean */
+ /* Output type is always boolean (cannot fail for builtins) */
out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
- if (out_dtypes[2] == NULL) {
- for (i = 0; i < 2; ++i) {
- Py_DECREF(out_dtypes[i]);
- out_dtypes[i] = NULL;
- }
- return -1;
- }
/* Check against the casting rules */
if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index b4f8d0c69..9e3fe387b 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -369,6 +369,64 @@ class TestComparisons:
with pytest.raises(TypeError, match="No loop matching"):
np.equal(1, 1, sig=(None, None, "l"))
+ @pytest.mark.parametrize("dtypes", ["qQ", "Qq"])
+ @pytest.mark.parametrize('py_comp, np_comp', [
+ (operator.lt, np.less),
+ (operator.le, np.less_equal),
+ (operator.gt, np.greater),
+ (operator.ge, np.greater_equal),
+ (operator.eq, np.equal),
+ (operator.ne, np.not_equal)
+ ])
+ @pytest.mark.parametrize("vals", [(2**60, 2**60+1), (2**60+1, 2**60)])
+ def test_large_integer_direct_comparison(
+ self, dtypes, py_comp, np_comp, vals):
+ # Note that float(2**60) + 1 == float(2**60).
+ a1 = np.array([2**60], dtype=dtypes[0])
+ a2 = np.array([2**60 + 1], dtype=dtypes[1])
+ expected = py_comp(2**60, 2**60+1)
+
+ assert py_comp(a1, a2) == expected
+ assert np_comp(a1, a2) == expected
+ # Also check the scalars:
+ s1 = a1[0]
+ s2 = a2[0]
+ assert isinstance(s1, np.integer)
+ assert isinstance(s2, np.integer)
+ # The Python operator here is mainly interesting:
+ assert py_comp(s1, s2) == expected
+ assert np_comp(s1, s2) == expected
+
+ @pytest.mark.parametrize("dtype", np.typecodes['UnsignedInteger'])
+ @pytest.mark.parametrize('py_comp_func, np_comp_func', [
+ (operator.lt, np.less),
+ (operator.le, np.less_equal),
+ (operator.gt, np.greater),
+ (operator.ge, np.greater_equal),
+ (operator.eq, np.equal),
+ (operator.ne, np.not_equal)
+ ])
+ @pytest.mark.parametrize("flip", [True, False])
+ def test_unsigned_signed_direct_comparison(
+ self, dtype, py_comp_func, np_comp_func, flip):
+ if flip:
+ py_comp = lambda x, y: py_comp_func(y, x)
+ np_comp = lambda x, y: np_comp_func(y, x)
+ else:
+ py_comp = py_comp_func
+ np_comp = np_comp_func
+
+ arr = np.array([np.iinfo(dtype).max], dtype=dtype)
+ expected = py_comp(int(arr[0]), -1)
+
+ assert py_comp(arr, -1) == expected
+ assert np_comp(arr, -1) == expected
+ scalar = arr[0]
+ assert isinstance(scalar, np.integer)
+ # The Python operator here is mainly interesting:
+ assert py_comp(scalar, -1) == expected
+ assert np_comp(scalar, -1) == expected
+
class TestAdd:
def test_reduce_alignment(self):