diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-03-22 12:20:00 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-22 12:20:00 +0100 |
commit | b35aac2c35ccfd5efadd7f72a090c9ad99308a60 (patch) | |
tree | b1bd47eb0b5f68488379d8ea631aaa9c360a7552 /numpy/core/src | |
parent | 294c7f2c893b7e5ef783fc1cb1912d06404b452b (diff) | |
parent | f3f108d313a8b8a4f7a90fb932867f17dc48b1f6 (diff) | |
download | numpy-b35aac2c35ccfd5efadd7f72a090c9ad99308a60.tar.gz |
Merge pull request #23240 from byrdie/bugfix/ufunc_where_propagation
ENH: Allow ``where`` argument to override ``__array_ufunc__``
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 5 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.h | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/override.c | 16 | ||||
-rw-r--r-- | numpy/core/src/umath/override.h | 2 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 6 |
6 files changed, 33 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index f518f3a02..93b290020 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -28,6 +28,7 @@ #include "strfuncs.h" #include "array_assign.h" #include "npy_dlpack.h" +#include "multiarraymodule.h" #include "methods.h" #include "alloc.h" @@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) int nin, nout; PyObject *out_kwd_obj; PyObject *fast; - PyObject **in_objs, **out_objs; + PyObject **in_objs, **out_objs, *where_obj; /* check inputs */ nin = PyTuple_Size(args); @@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) } } Py_DECREF(out_kwd_obj); + /* check where if it exists */ + where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where); + if (where_obj == NULL) { + if (PyErr_Occurred()) { + return -1; + } + } else { + if (PyUFunc_HasOverride(where_obj)){ + return 1; + } + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index e85f8affa..ac8e641b7 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL; +NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL; static int intern_strings(void) @@ -4899,6 +4900,10 @@ intern_strings(void) if (npy_ma_str_numpy == NULL) { return -1; } + npy_ma_str_where = PyUnicode_InternFromString("where"); + if (npy_ma_str_where == NULL) { + return -1; + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h index 992acd09f..9ba2a1831 100644 --- a/numpy/core/src/multiarray/multiarraymodule.h +++ b/numpy/core/src/multiarray/multiarraymodule.h @@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy; +NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where; #endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */ diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c index d247c2639..167164163 100644 --- a/numpy/core/src/umath/override.c +++ b/numpy/core/src/umath/override.c @@ -23,18 +23,19 @@ * Returns -1 on failure. */ static int -get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, +get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject **with_override, PyObject **methods) { int i; int num_override_args = 0; - int narg, nout; + int narg, nout, nwhere; narg = (int)PyTuple_GET_SIZE(in_args); /* It is valid for out_args to be NULL: */ nout = (out_args != NULL) ? (int)PyTuple_GET_SIZE(out_args) : 0; + nwhere = (wheremask_obj != NULL) ? 1: 0; - for (i = 0; i < narg + nout; ++i) { + for (i = 0; i < narg + nout + nwhere; ++i) { PyObject *obj; int j; int new_class = 1; @@ -42,9 +43,12 @@ get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, if (i < narg) { obj = PyTuple_GET_ITEM(in_args, i); } - else { + else if (i < narg + nout){ obj = PyTuple_GET_ITEM(out_args, i - narg); } + else { + obj = wheremask_obj; + } /* * Have we seen this class before? If so, ignore. */ @@ -208,7 +212,7 @@ copy_positional_args_to_kwargs(const char **keywords, */ NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, - PyObject *in_args, PyObject *out_args, + PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames, PyObject **result) { @@ -227,7 +231,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, * Check inputs for overrides */ num_override_args = get_array_ufunc_overrides( - in_args, out_args, with_override, array_ufunc_methods); + in_args, out_args, wheremask_obj, with_override, array_ufunc_methods); if (num_override_args == -1) { goto fail; } diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h index 4e9a323ca..20621bb19 100644 --- a/numpy/core/src/umath/override.h +++ b/numpy/core/src/umath/override.h @@ -6,7 +6,7 @@ NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, - PyObject *in_args, PyObject *out_args, + PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames, PyObject **result); diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index a159003de..a5e8f4cbe 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4071,7 +4071,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, /* We now have all the information required to check for Overrides */ PyObject *override = NULL; int errval = PyUFunc_CheckOverride(ufunc, _reduce_type[operation], - full_args.in, full_args.out, args, len_args, kwnames, &override); + full_args.in, full_args.out, wheremask_obj, args, len_args, kwnames, &override); if (errval) { return NULL; } @@ -4843,7 +4843,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc, /* We now have all the information required to check for Overrides */ PyObject *override = NULL; errval = PyUFunc_CheckOverride(ufunc, method, - full_args.in, full_args.out, + full_args.in, full_args.out, where_obj, args, len_args, kwnames, &override); if (errval) { goto fail; @@ -6261,7 +6261,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) return NULL; } errval = PyUFunc_CheckOverride(ufunc, "at", - args, NULL, NULL, 0, NULL, &override); + args, NULL, NULL, NULL, 0, NULL, &override); if (errval) { return NULL; |