diff options
-rw-r--r-- | numpy/core/src/umath/dispatching.c | 74 | ||||
-rw-r--r-- | numpy/core/src/umath/dispatching.h | 4 |
2 files changed, 68 insertions, 10 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index 1d3ad9dff..6bad5bd38 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -97,8 +97,9 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate) return -1; } } - if (!PyObject_TypeCheck(PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) { - /* Must also accept promoters in the future. */ + PyObject *meth_or_promoter = PyTuple_GET_ITEM(info, 1); + if (!PyObject_TypeCheck(meth_or_promoter, &PyArrayMethod_Type) + && !PyCapsule_IsValid(meth_or_promoter, "numpy._ufunc_promoter")) { PyErr_SetString(PyExc_TypeError, "Second argument to info must be an ArrayMethod or promoter"); return -1; @@ -353,15 +354,68 @@ resolve_implementation_info(PyUFuncObject *ufunc, * those defined by the `signature` unmodified). */ static PyObject * -call_promoter_and_recurse( - PyUFuncObject *NPY_UNUSED(ufunc), PyObject *NPY_UNUSED(promoter), - PyArray_DTypeMeta *NPY_UNUSED(op_dtypes[]), - PyArray_DTypeMeta *NPY_UNUSED(signature[]), - PyArrayObject *const NPY_UNUSED(operands[])) +call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter, + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArrayObject *const operands[]) { - PyErr_SetString(PyExc_NotImplementedError, - "Internal NumPy error, promoters are not used/implemented yet."); - return NULL; + int nargs = ufunc->nargs; + PyObject *resolved_info = NULL; + + int promoter_result; + PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS]; + + if (PyCapsule_CheckExact(promoter)) { + /* We could also go the other way and wrap up the python function... */ + promoter_function *promoter_function = PyCapsule_GetPointer(promoter, + "numpy._ufunc_promoter"); + if (promoter_function == NULL) { + return NULL; + } + promoter_result = promoter_function(ufunc, + op_dtypes, signature, new_op_dtypes); + } + else { + PyErr_SetString(PyExc_NotImplementedError, + "Calling python functions for promotion is not implemented."); + return NULL; + } + if (promoter_result < 0) { + return NULL; + } + /* + * If none of the dtypes changes, we would recurse infinitely, abort. + * (Of course it is nevertheless possible to recurse infinitely.) + */ + int dtypes_changed = 0; + for (int i = 0; i < nargs; i++) { + if (new_op_dtypes[i] != op_dtypes[i]) { + dtypes_changed = 1; + break; + } + } + if (!dtypes_changed) { + goto finish; + } + + /* + * Do a recursive call, the promotion function has to ensure that the + * new tuple is strictly more precise (thus guaranteeing eventual finishing) + */ + if (Py_EnterRecursiveCall(" during ufunc promotion.") != 0) { + goto finish; + } + /* TODO: The caching logic here may need revising: */ + resolved_info = promote_and_get_info_and_ufuncimpl(ufunc, + operands, signature, new_op_dtypes, + /* no legacy promotion */ NPY_FALSE, /* cache */ NPY_TRUE); + + Py_LeaveRecursiveCall(); + + finish: + for (int i = 0; i < nargs; i++) { + Py_XDECREF(new_op_dtypes[i]); + } + return resolved_info; } diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h index b01bc79fa..8d116873c 100644 --- a/numpy/core/src/umath/dispatching.h +++ b/numpy/core/src/umath/dispatching.h @@ -7,6 +7,10 @@ #include "array_method.h" +typedef int promoter_function(PyUFuncObject *ufunc, + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[]); + NPY_NO_EXPORT int PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); |