diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-05-04 16:33:27 +0200 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-05-04 16:33:27 +0200 |
commit | ec8d5db302c0e8597feb058f58863d5e9a6554c1 (patch) | |
tree | 6edf099a4deebdc5ab86fdad314a892d3db7db7b | |
parent | c37a577c9df74e29c97a7bb010de0b37f83870bb (diff) | |
download | numpy-ec8d5db302c0e8597feb058f58863d5e9a6554c1.tar.gz |
ENH: Make signed/unsigned integer comparisons exact
This makes comparisons between signed and unsigned integers exact
by special-casing promotion in comparison to never promote integers
to floats, but rather promote them to uint64 or int64 and use a
specific loop for that purpose.
This is a bit lazy: it doesn't make the scalar paths fast (they never were,
though), nor does it try to vectorize the loop.
Thus, for cases that are not int64/uint64 already and require a cast in
either case, it should be a bit slower. OTOH, it was never really fast
and the int64/uint64 mix is probably faster since it avoids casting.
---
Now... the reason I was looking into this was that I had hoped
it would help with NEP 50/weak scalar typing to allow:
uint64(1) < -1 # annoying that it fails with NEP 50
but it doesn't actually, because if I use int64 for the -1, then very
large numbers would be a problem...
I could probably(?) add a *specific* "Python integer" ArrayMethod for comparisons
and that could pick `object` dtype and thus get the original Python object
(the loop could then in practice assume a scalar value).
---
In either case, this works, and unless we worry about keeping the existing
behavior, we might as well do this.
(Potentially with follow-ups to speed it up.)
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 52 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 42 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.h.src | 11 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_type_resolution.c | 33 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 58 |
5 files changed, 173 insertions, 23 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index d0306bb72..a170f83be 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -109,6 +109,11 @@ def _check_order(types1, types2): if t2i > t1i: break + if types1 == "QQ?" and types2 == "qQ?": + # Explicitly allow this mixed case, rather than figure out what order + # is nicer or how to encode it. + return + raise TypeError( f"Input dtypes are unsorted or duplicate: {types1} and {types2}") @@ -523,49 +528,67 @@ defdict = { Ufunc(2, 1, None, docstrings.get('numpy.core.umath.greater'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'greater_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.greater_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'less': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.less'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', 
dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'less_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.less_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'not_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.not_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', - TD(all, out='?', dispatch=[('loops_comparison', bints+'fd')]), - [TypeDescription('O', FullTypeDescr, 'OO', 'O')], + TD(bints, out='?'), + [TypeDescription('q', FullTypeDescr, 'qQ', '?'), + TypeDescription('q', FullTypeDescr, 'Qq', '?')], + TD(inexact+times, out='?', dispatch=[('loops_comparison', bints+'fd')]), TD('O', out='?'), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'logical_and': Ufunc(2, 1, True_, @@ -1172,7 +1195,10 @@ def make_arrays(funcdict): if t.func_data is FullTypeDescr: tname = english_upper(chartoname[t.type]) datalist.append('(void *)NULL') - cfunc_fname = f"{tname}_{t.in_}_{t.out}_{cfunc_alias}" + if t.out == "?": + cfunc_fname = f"{tname}_{t.in_}_bool_{cfunc_alias}" + else: + cfunc_fname = 
f"{tname}_{t.in_}_{t.out}_{cfunc_alias}" elif isinstance(t.func_data, FuncNameSuffix): datalist.append('(void *)NULL') tname = english_upper(chartoname[t.type]) diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index 397ebaca2..97a74b425 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -545,6 +545,48 @@ NPY_NO_EXPORT void /**end repeat1**/ /**end repeat**/ + +/* + * NOTE: It may be nice to vectorize these, OTOH, these are still faster + * than the cast we used to do. + */ + +/**begin repeat + * #kind = equal, not_equal, less, less_equal, greater, greater_equal# + * #OP = ==, !=, <, <=, >, >=# + */ +NPY_NO_EXPORT void +LONGLONG_Qq_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) +{ + BINARY_LOOP { + const npy_ulonglong in1 = *(npy_ulonglong *)ip1; + const npy_longlong in2 = *(npy_longlong *)ip2; + if (in2 < 0) { + *(npy_bool *)op1 = 0 @OP@ in2; + } + else { + *(npy_bool *)op1 = in1 @OP@ (npy_ulonglong)in2; + } + } +} + +NPY_NO_EXPORT void +LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) +{ + BINARY_LOOP { + const npy_longlong in1 = *(npy_longlong *)ip1; + const npy_ulonglong in2 = *(npy_ulonglong *)ip2; + if (in1 < 0) { + *(npy_bool *)op1 = in1 @OP@ 0; + } + else { + *(npy_bool *)op1 = (npy_ulonglong)in1 @OP@ in2; + } + } +} +/**end repeat**/ + + /* ***************************************************************************** ** DATETIME LOOPS ** diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src index ab54c1966..cce73aff8 100644 --- a/numpy/core/src/umath/loops.h.src +++ b/numpy/core/src/umath/loops.h.src @@ -211,6 +211,17 @@ NPY_NO_EXPORT void /**end repeat1**/ /**end repeat**/ +/**begin repeat + * #kind = equal, not_equal, less, less_equal, greater, greater_equal# + * #OP = ==, !=, <, <=, >, >=# + */ +NPY_NO_EXPORT void +LONGLONG_Qq_bool_@kind@(char 
**args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)); +NPY_NO_EXPORT void +LONGLONG_qQ_bool_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)); + +/**end repeat**/ + #ifndef NPY_DISABLE_OPTIMIZATION #include "loops_unary.dispatch.h" diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c index 12187d059..decd26580 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.c +++ b/numpy/core/src/umath/ufunc_type_resolution.c @@ -381,8 +381,28 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc, if (out_dtypes[0] == NULL) { return -1; } - out_dtypes[1] = out_dtypes[0]; - Py_INCREF(out_dtypes[1]); + if (PyArray_ISINTEGER(operands[0]) + && PyArray_ISINTEGER(operands[1]) + && !PyDataType_ISINTEGER(out_dtypes[0])) { + /* + * NumPy promotion allows uint+int to go to float, avoid it + * (input must have been a mix of signed and unsigned) + */ + if (PyArray_ISSIGNED(operands[0])) { + Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_LONGLONG)); + out_dtypes[1] = PyArray_DescrFromType(NPY_ULONGLONG); + Py_INCREF(out_dtypes[1]); + } + else { + Py_SETREF(out_dtypes[0], PyArray_DescrFromType(NPY_ULONGLONG)); + out_dtypes[1] = PyArray_DescrFromType(NPY_LONGLONG); + Py_INCREF(out_dtypes[1]); + } + } + else { + out_dtypes[1] = out_dtypes[0]; + Py_INCREF(out_dtypes[1]); + } } else { /* Not doing anything will lead to a loop no found error. 
*/ @@ -398,15 +418,8 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc, operands, type_tup, out_dtypes); } - /* Output type is always boolean */ + /* Output type is always boolean (cannot fail for builtins) */ out_dtypes[2] = PyArray_DescrFromType(NPY_BOOL); - if (out_dtypes[2] == NULL) { - for (i = 0; i < 2; ++i) { - Py_DECREF(out_dtypes[i]); - out_dtypes[i] = NULL; - } - return -1; - } /* Check against the casting rules */ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index b4f8d0c69..9e3fe387b 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -369,6 +369,64 @@ class TestComparisons: with pytest.raises(TypeError, match="No loop matching"): np.equal(1, 1, sig=(None, None, "l")) + @pytest.mark.parametrize("dtypes", ["qQ", "Qq"]) + @pytest.mark.parametrize('py_comp, np_comp', [ + (operator.lt, np.less), + (operator.le, np.less_equal), + (operator.gt, np.greater), + (operator.ge, np.greater_equal), + (operator.eq, np.equal), + (operator.ne, np.not_equal) + ]) + @pytest.mark.parametrize("vals", [(2**60, 2**60+1), (2**60+1, 2**60)]) + def test_large_integer_direct_comparison( + self, dtypes, py_comp, np_comp, vals): + # Note that float(2**60) + 1 == float(2**60). 
+ a1 = np.array([2**60], dtype=dtypes[0]) + a2 = np.array([2**60 + 1], dtype=dtypes[1]) + expected = py_comp(2**60, 2**60+1) + + assert py_comp(a1, a2) == expected + assert np_comp(a1, a2) == expected + # Also check the scalars: + s1 = a1[0] + s2 = a2[0] + assert isinstance(s1, np.integer) + assert isinstance(s2, np.integer) + # The Python operator here is mainly interesting: + assert py_comp(s1, s2) == expected + assert np_comp(s1, s2) == expected + + @pytest.mark.parametrize("dtype", np.typecodes['UnsignedInteger']) + @pytest.mark.parametrize('py_comp_func, np_comp_func', [ + (operator.lt, np.less), + (operator.le, np.less_equal), + (operator.gt, np.greater), + (operator.ge, np.greater_equal), + (operator.eq, np.equal), + (operator.ne, np.not_equal) + ]) + @pytest.mark.parametrize("flip", [True, False]) + def test_unsigned_signed_direct_comparison( + self, dtype, py_comp_func, np_comp_func, flip): + if flip: + py_comp = lambda x, y: py_comp_func(y, x) + np_comp = lambda x, y: np_comp_func(y, x) + else: + py_comp = py_comp_func + np_comp = np_comp_func + + arr = np.array([np.iinfo(dtype).max], dtype=dtype) + expected = py_comp(int(arr[0]), -1) + + assert py_comp(arr, -1) == expected + assert np_comp(arr, -1) == expected + scalar = arr[0] + assert isinstance(scalar, np.integer) + # The Python operator here is mainly interesting: + assert py_comp(scalar, -1) == expected + assert np_comp(scalar, -1) == expected + class TestAdd: def test_reduce_alignment(self): |