diff options
-rw-r--r-- | numpy/core/src/umath/_scaled_float_dtype.c | 67 | ||||
-rw-r--r-- | numpy/core/src/umath/dispatching.c | 74 | ||||
-rw-r--r-- | numpy/core/src/umath/dispatching.h | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_custom_dtypes.py | 12 |
4 files changed, 138 insertions, 19 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c index 599774cce..cbea378f0 100644 --- a/numpy/core/src/umath/_scaled_float_dtype.c +++ b/numpy/core/src/umath/_scaled_float_dtype.c @@ -464,9 +464,6 @@ init_casts(void) * 2. Addition, which needs to use the common instance, and runs into * cast safety subtleties since we will implement it without an additional * cast. - * - * NOTE: When first writing this, promotion did not exist for new-style loops, - * if it exists, we could use promotion to implement double * sfloat. */ static int multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context), @@ -591,7 +588,8 @@ add_sfloats_resolve_descriptors( static int -add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) +add_loop(const char *ufunc_name, + PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter) { PyObject *mod = PyImport_ImportModule("numpy"); if (mod == NULL) { @@ -605,13 +603,12 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) "numpy.%s was not a ufunc!", ufunc_name); return -1; } - PyObject *dtype_tup = PyArray_TupleFromItems( - 3, (PyObject **)bmeth->dtypes, 0); + PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1); if (dtype_tup == NULL) { Py_DECREF(ufunc); return -1; } - PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method); + PyObject *info = PyTuple_Pack(2, dtype_tup, meth_or_promoter); Py_DECREF(dtype_tup); if (info == NULL) { Py_DECREF(ufunc); @@ -624,6 +621,28 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth) } + +/* + * We add some very basic promoters to allow multiplying normal and scaled + */ +static int +promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc), + PyArray_DTypeMeta *const NPY_UNUSED(dtypes[3]), + PyArray_DTypeMeta *const signature[3], + PyArray_DTypeMeta *new_dtypes[3]) +{ + for (int i = 0; i < 3; i++) { + PyArray_DTypeMeta *new = &PyArray_SFloatDType; + if (signature[i] != NULL) { + new = signature[i]; + } + Py_INCREF(new); + new_dtypes[i] = new; + } + return 0; +} + + /* * Add new ufunc loops (this is somewhat clumsy as of writing it, but should * get less so with the introduction of public API). @@ -650,7 +669,8 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - int res = add_loop("multiply", bmeth); + int res = add_loop("multiply", + bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { return -1; @@ -667,11 +687,40 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - res = add_loop("add", bmeth); + res = add_loop("add", + bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { return -1; } + + /* + * Add a promoter for both directions of multiply with double. + */ + PyArray_DTypeMeta *double_DType = PyArray_DTypeFromTypeNum(NPY_DOUBLE); + Py_DECREF(double_DType); /* immortal anyway */ + + PyArray_DTypeMeta *promoter_dtypes[3] = { + &PyArray_SFloatDType, double_DType, NULL}; + + PyObject *promoter = PyCapsule_New( + &promote_to_sfloat, "numpy._ufunc_promoter", NULL); + if (promoter == NULL) { + return -1; + } + res = add_loop("multiply", promoter_dtypes, promoter); + if (res < 0) { + Py_DECREF(promoter); + return -1; + } + promoter_dtypes[0] = double_DType; + promoter_dtypes[1] = &PyArray_SFloatDType; + res = add_loop("multiply", promoter_dtypes, promoter); + Py_DECREF(promoter); + if (res < 0) { + return -1; + } + return 0; } diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index b1c5ccb6b..b97441b13 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -97,8 +97,9 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate) return -1; } } - if (!PyObject_TypeCheck(PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) { - /* Must also accept promoters in the future. */ + PyObject *meth_or_promoter = PyTuple_GET_ITEM(info, 1); + if (!PyObject_TypeCheck(meth_or_promoter, &PyArrayMethod_Type) + && !PyCapsule_IsValid(meth_or_promoter, "numpy._ufunc_promoter")) { PyErr_SetString(PyExc_TypeError, "Second argument to info must be an ArrayMethod or promoter"); return -1; @@ -354,15 +355,68 @@ resolve_implementation_info(PyUFuncObject *ufunc, * those defined by the `signature` unmodified). */ static PyObject * -call_promoter_and_recurse( - PyUFuncObject *NPY_UNUSED(ufunc), PyObject *NPY_UNUSED(promoter), - PyArray_DTypeMeta *NPY_UNUSED(op_dtypes[]), - PyArray_DTypeMeta *NPY_UNUSED(signature[]), - PyArrayObject *const NPY_UNUSED(operands[])) +call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter, + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArrayObject *const operands[]) { - PyErr_SetString(PyExc_NotImplementedError, - "Internal NumPy error, promoters are not used/implemented yet."); - return NULL; + int nargs = ufunc->nargs; + PyObject *resolved_info = NULL; + + int promoter_result; + PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS]; + + if (PyCapsule_CheckExact(promoter)) { + /* We could also go the other way and wrap up the python function... */ + promoter_function *promoter_function = PyCapsule_GetPointer(promoter, + "numpy._ufunc_promoter"); + if (promoter_function == NULL) { + return NULL; + } + promoter_result = promoter_function(ufunc, + op_dtypes, signature, new_op_dtypes); + } + else { + PyErr_SetString(PyExc_NotImplementedError, + "Calling python functions for promotion is not implemented."); + return NULL; + } + if (promoter_result < 0) { + return NULL; + } + /* + * If none of the dtypes changes, we would recurse infinitely, abort. + * (Of course it is nevertheless possible to recurse infinitely.) + */ + int dtypes_changed = 0; + for (int i = 0; i < nargs; i++) { + if (new_op_dtypes[i] != op_dtypes[i]) { + dtypes_changed = 1; + break; + } + } + if (!dtypes_changed) { + goto finish; + } + + /* + * Do a recursive call, the promotion function has to ensure that the + * new tuple is strictly more precise (thus guaranteeing eventual finishing) + */ + if (Py_EnterRecursiveCall(" during ufunc promotion.") != 0) { + goto finish; + } + /* TODO: The caching logic here may need revising: */ + resolved_info = promote_and_get_info_and_ufuncimpl(ufunc, + operands, signature, new_op_dtypes, + /* no legacy promotion */ NPY_FALSE, /* cache */ NPY_TRUE); + + Py_LeaveRecursiveCall(); + + finish: + for (int i = 0; i < nargs; i++) { + Py_XDECREF(new_op_dtypes[i]); + } + return resolved_info; } diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h index b01bc79fa..8d116873c 100644 --- a/numpy/core/src/umath/dispatching.h +++ b/numpy/core/src/umath/dispatching.h @@ -7,6 +7,10 @@ #include "array_method.h" +typedef int promoter_function(PyUFuncObject *ufunc, + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[]); + NPY_NO_EXPORT int PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); diff --git a/numpy/core/tests/test_custom_dtypes.py b/numpy/core/tests/test_custom_dtypes.py index 3ec2363b9..5eb82bc93 100644 --- a/numpy/core/tests/test_custom_dtypes.py +++ b/numpy/core/tests/test_custom_dtypes.py @@ -101,6 +101,18 @@ class TestSFloat: expected_view = a.view(np.float64) * b.view(np.float64) assert_array_equal(res.view(np.float64), expected_view) + def test_basic_multiply_promotion(self): + float_a = np.array([1., 2., 3.]) + b = self._get_array(2.) + + res1 = float_a * b + res2 = b * float_a + # one factor is one, so we get the factor of b: + assert res1.dtype == res2.dtype == b.dtype + expected_view = float_a * b.view(np.float64) + assert_array_equal(res1.view(np.float64), expected_view) + assert_array_equal(res2.view(np.float64), expected_view) + def test_basic_addition(self): a = self._get_array(2.) b = self._get_array(4.) |