summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-03-22 12:20:00 +0100
committerGitHub <noreply@github.com>2023-03-22 12:20:00 +0100
commitb35aac2c35ccfd5efadd7f72a090c9ad99308a60 (patch)
treeb1bd47eb0b5f68488379d8ea631aaa9c360a7552 /numpy/core/src
parent294c7f2c893b7e5ef783fc1cb1912d06404b452b (diff)
parentf3f108d313a8b8a4f7a90fb932867f17dc48b1f6 (diff)
downloadnumpy-b35aac2c35ccfd5efadd7f72a090c9ad99308a60.tar.gz
Merge pull request #23240 from byrdie/bugfix/ufunc_where_propagation
ENH: Allow ``where`` argument to override ``__array_ufunc__``
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/multiarray/methods.c14
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.h1
-rw-r--r--numpy/core/src/umath/override.c16
-rw-r--r--numpy/core/src/umath/override.h2
-rw-r--r--numpy/core/src/umath/ufunc_object.c6
6 files changed, 33 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index f518f3a02..93b290020 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -28,6 +28,7 @@
#include "strfuncs.h"
#include "array_assign.h"
#include "npy_dlpack.h"
+#include "multiarraymodule.h"
#include "methods.h"
#include "alloc.h"
@@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
int nin, nout;
PyObject *out_kwd_obj;
PyObject *fast;
- PyObject **in_objs, **out_objs;
+ PyObject **in_objs, **out_objs, *where_obj;
/* check inputs */
nin = PyTuple_Size(args);
@@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
}
}
Py_DECREF(out_kwd_obj);
+ /* check where if it exists */
+ where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where);
+ if (where_obj == NULL) {
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ } else {
+ if (PyUFunc_HasOverride(where_obj)){
+ return 1;
+ }
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index e85f8affa..ac8e641b7 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL;
+NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL;
static int
intern_strings(void)
@@ -4899,6 +4900,10 @@ intern_strings(void)
if (npy_ma_str_numpy == NULL) {
return -1;
}
+ npy_ma_str_where = PyUnicode_InternFromString("where");
+ if (npy_ma_str_where == NULL) {
+ return -1;
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h
index 992acd09f..9ba2a1831 100644
--- a/numpy/core/src/multiarray/multiarraymodule.h
+++ b/numpy/core/src/multiarray/multiarraymodule.h
@@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy;
+NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where;
#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */
diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c
index d247c2639..167164163 100644
--- a/numpy/core/src/umath/override.c
+++ b/numpy/core/src/umath/override.c
@@ -23,18 +23,19 @@
* Returns -1 on failure.
*/
static int
-get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
+get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject **with_override, PyObject **methods)
{
int i;
int num_override_args = 0;
- int narg, nout;
+ int narg, nout, nwhere;
narg = (int)PyTuple_GET_SIZE(in_args);
/* It is valid for out_args to be NULL: */
nout = (out_args != NULL) ? (int)PyTuple_GET_SIZE(out_args) : 0;
+ nwhere = (wheremask_obj != NULL) ? 1: 0;
- for (i = 0; i < narg + nout; ++i) {
+ for (i = 0; i < narg + nout + nwhere; ++i) {
PyObject *obj;
int j;
int new_class = 1;
@@ -42,9 +43,12 @@ get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
if (i < narg) {
obj = PyTuple_GET_ITEM(in_args, i);
}
- else {
+ else if (i < narg + nout){
obj = PyTuple_GET_ITEM(out_args, i - narg);
}
+ else {
+ obj = wheremask_obj;
+ }
/*
* Have we seen this class before? If so, ignore.
*/
@@ -208,7 +212,7 @@ copy_positional_args_to_kwargs(const char **keywords,
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result)
{
@@ -227,7 +231,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
* Check inputs for overrides
*/
num_override_args = get_array_ufunc_overrides(
- in_args, out_args, with_override, array_ufunc_methods);
+ in_args, out_args, wheremask_obj, with_override, array_ufunc_methods);
if (num_override_args == -1) {
goto fail;
}
diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h
index 4e9a323ca..20621bb19 100644
--- a/numpy/core/src/umath/override.h
+++ b/numpy/core/src/umath/override.h
@@ -6,7 +6,7 @@
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result);
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index a159003de..a5e8f4cbe 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4071,7 +4071,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
int errval = PyUFunc_CheckOverride(ufunc, _reduce_type[operation],
- full_args.in, full_args.out, args, len_args, kwnames, &override);
+ full_args.in, full_args.out, wheremask_obj, args, len_args, kwnames, &override);
if (errval) {
return NULL;
}
@@ -4843,7 +4843,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
errval = PyUFunc_CheckOverride(ufunc, method,
- full_args.in, full_args.out,
+ full_args.in, full_args.out, where_obj,
args, len_args, kwnames, &override);
if (errval) {
goto fail;
@@ -6261,7 +6261,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
return NULL;
}
errval = PyUFunc_CheckOverride(ufunc, "at",
- args, NULL, NULL, 0, NULL, &override);
+ args, NULL, NULL, NULL, 0, NULL, &override);
if (errval) {
return NULL;