summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/umath/dispatching.c74
-rw-r--r--numpy/core/src/umath/dispatching.h4
2 files changed, 68 insertions, 10 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index 1d3ad9dff..6bad5bd38 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -97,8 +97,9 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
return -1;
}
}
- if (!PyObject_TypeCheck(PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
- /* Must also accept promoters in the future. */
+ PyObject *meth_or_promoter = PyTuple_GET_ITEM(info, 1);
+ if (!PyObject_TypeCheck(meth_or_promoter, &PyArrayMethod_Type)
+ && !PyCapsule_IsValid(meth_or_promoter, "numpy._ufunc_promoter")) {
PyErr_SetString(PyExc_TypeError,
"Second argument to info must be an ArrayMethod or promoter");
return -1;
@@ -353,15 +354,68 @@ resolve_implementation_info(PyUFuncObject *ufunc,
* those defined by the `signature` unmodified).
*/
static PyObject *
-call_promoter_and_recurse(
- PyUFuncObject *NPY_UNUSED(ufunc), PyObject *NPY_UNUSED(promoter),
- PyArray_DTypeMeta *NPY_UNUSED(op_dtypes[]),
- PyArray_DTypeMeta *NPY_UNUSED(signature[]),
- PyArrayObject *const NPY_UNUSED(operands[]))
+call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
+ PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
+ PyArrayObject *const operands[])
{
- PyErr_SetString(PyExc_NotImplementedError,
- "Internal NumPy error, promoters are not used/implemented yet.");
- return NULL;
+ int nargs = ufunc->nargs;
+ PyObject *resolved_info = NULL;
+
+ int promoter_result;
+ PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS];
+
+ if (PyCapsule_CheckExact(promoter)) {
+ /* We could also go the other way and wrap up the python function... */
+ promoter_function *promoter_function = PyCapsule_GetPointer(promoter,
+ "numpy._ufunc_promoter");
+ if (promoter_function == NULL) {
+ return NULL;
+ }
+ promoter_result = promoter_function(ufunc,
+ op_dtypes, signature, new_op_dtypes);
+ }
+ else {
+ PyErr_SetString(PyExc_NotImplementedError,
+ "Calling python functions for promotion is not implemented.");
+ return NULL;
+ }
+ if (promoter_result < 0) {
+ return NULL;
+ }
+ /*
+ * If none of the dtypes changes, we would recurse infinitely, abort.
+ * (Of course it is nevertheless possible to recurse infinitely.)
+ */
+ int dtypes_changed = 0;
+ for (int i = 0; i < nargs; i++) {
+ if (new_op_dtypes[i] != op_dtypes[i]) {
+ dtypes_changed = 1;
+ break;
+ }
+ }
+ if (!dtypes_changed) {
+ goto finish;
+ }
+
+ /*
+ * Do a recursive call, the promotion function has to ensure that the
+ * new tuple is strictly more precise (thus guaranteeing eventual finishing)
+ */
+ if (Py_EnterRecursiveCall(" during ufunc promotion.") != 0) {
+ goto finish;
+ }
+ /* TODO: The caching logic here may need revising: */
+ resolved_info = promote_and_get_info_and_ufuncimpl(ufunc,
+ operands, signature, new_op_dtypes,
+ /* no legacy promotion */ NPY_FALSE, /* cache */ NPY_TRUE);
+
+ Py_LeaveRecursiveCall();
+
+ finish:
+ for (int i = 0; i < nargs; i++) {
+ Py_XDECREF(new_op_dtypes[i]);
+ }
+ return resolved_info;
}
diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h
index b01bc79fa..8d116873c 100644
--- a/numpy/core/src/umath/dispatching.h
+++ b/numpy/core/src/umath/dispatching.h
@@ -7,6 +7,10 @@
#include "array_method.h"
+typedef int promoter_function(PyUFuncObject *ufunc,
+ PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
+ PyArray_DTypeMeta *new_op_dtypes[]);
+
NPY_NO_EXPORT int
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);