diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2022-02-26 12:39:19 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-26 12:39:19 -0700 |
commit | 3f7d9d8f32d5ca472335ffcf501a0d3f67a9352f (patch) | |
tree | d6aab99f1fe5ef17ba6d55b0f436fd3902208c04 /numpy | |
parent | d438c4c76270d83620d2ae71d794c8ca91df0445 (diff) | |
parent | 454ffba2a54449812c442852b31ec7026aff4d0f (diff) | |
download | numpy-3f7d9d8f32d5ca472335ffcf501a0d3f67a9352f.tar.gz |
Merge pull request #21113 from seberg/fix-numba
BUG: Fix numba DUFuncs added loops getting picked up
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/dispatching.c | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index c3f0e1e67..b8f102b3d 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -746,6 +746,40 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, } info = promote_and_get_info_and_ufuncimpl(ufunc, ops, signature, new_op_dtypes, NPY_FALSE); + if (info == NULL) { + /* + * NOTE: This block exists solely to support numba's DUFuncs which add + * new loops dynamically, so our list may get outdated. Thus, we + * have to make sure that the loop exists. + * + * Before adding a new loop, ensure that it actually exists. There + * is a tiny chance that this would not work, but it would require an + * extension additionally have a custom loop getter. + * This check should ensure a the right error message, but in principle + * we could try to call the loop getter here. + */ + char *types = ufunc->types; + npy_bool loop_exists = NPY_FALSE; + for (int i = 0; i < ufunc->ntypes; ++i) { + loop_exists = NPY_TRUE; /* assume it exists, break if not */ + for (int j = 0; j < ufunc->nargs; ++j) { + if (types[j] != new_op_dtypes[j]->type_num) { + loop_exists = NPY_FALSE; + break; + } + } + if (loop_exists) { + break; + } + types += ufunc->nargs; + } + + if (loop_exists) { + info = add_and_return_legacy_wrapping_ufunc_loop( + ufunc, new_op_dtypes, 0); + } + } + for (int i = 0; i < ufunc->nargs; i++) { Py_XDECREF(new_op_dtypes[i]); } |