summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-02-26 12:39:19 -0700
committerGitHub <noreply@github.com>2022-02-26 12:39:19 -0700
commit3f7d9d8f32d5ca472335ffcf501a0d3f67a9352f (patch)
treed6aab99f1fe5ef17ba6d55b0f436fd3902208c04 /numpy
parentd438c4c76270d83620d2ae71d794c8ca91df0445 (diff)
parent454ffba2a54449812c442852b31ec7026aff4d0f (diff)
downloadnumpy-3f7d9d8f32d5ca472335ffcf501a0d3f67a9352f.tar.gz
Merge pull request #21113 from seberg/fix-numba
BUG: Fix numba DUFuncs added loops getting picked up
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/dispatching.c34
1 files changed, 34 insertions, 0 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index c3f0e1e67..b8f102b3d 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -746,6 +746,40 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
}
info = promote_and_get_info_and_ufuncimpl(ufunc,
ops, signature, new_op_dtypes, NPY_FALSE);
+ if (info == NULL) {
+ /*
+ * NOTE: This block exists solely to support numba's DUFuncs which add
+ * new loops dynamically, so our list may get outdated. Thus, we
+ * have to make sure that the loop exists.
+ *
+ * Before adding a new loop, ensure that it actually exists. There
+ * is a tiny chance that this would not work, but it would require an
+ * extension additionally have a custom loop getter.
+ * This check should ensure a the right error message, but in principle
+ * we could try to call the loop getter here.
+ */
+ char *types = ufunc->types;
+ npy_bool loop_exists = NPY_FALSE;
+ for (int i = 0; i < ufunc->ntypes; ++i) {
+ loop_exists = NPY_TRUE; /* assume it exists, break if not */
+ for (int j = 0; j < ufunc->nargs; ++j) {
+ if (types[j] != new_op_dtypes[j]->type_num) {
+ loop_exists = NPY_FALSE;
+ break;
+ }
+ }
+ if (loop_exists) {
+ break;
+ }
+ types += ufunc->nargs;
+ }
+
+ if (loop_exists) {
+ info = add_and_return_legacy_wrapping_ufunc_loop(
+ ufunc, new_op_dtypes, 0);
+ }
+ }
+
for (int i = 0; i < ufunc->nargs; i++) {
Py_XDECREF(new_op_dtypes[i]);
}