summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2021-08-08 11:55:57 +0300
committerGitHub <noreply@github.com>2021-08-08 11:55:57 +0300
commit455f1f3168ab0105d05add646ff63a17b80fdcdc (patch)
treee134b3f1f35af03e69024cb8e4c601faabe32a32 /numpy
parent261a769dd135c1e4007a6d0bdbed68a8c929497a (diff)
parent693966c8e6677cc573f19f436aa2a310d986e4ed (diff)
downloadnumpy-455f1f3168ab0105d05add646ff63a17b80fdcdc.tar.gz
Merge pull request #19580 from seberg/basic-promotion
ENH: Add basic promoter capability to ufunc dispatching
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/_scaled_float_dtype.c67
-rw-r--r--numpy/core/src/umath/dispatching.c74
-rw-r--r--numpy/core/src/umath/dispatching.h4
-rw-r--r--numpy/core/tests/test_custom_dtypes.py12
4 files changed, 138 insertions, 19 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c
index 599774cce..cbea378f0 100644
--- a/numpy/core/src/umath/_scaled_float_dtype.c
+++ b/numpy/core/src/umath/_scaled_float_dtype.c
@@ -464,9 +464,6 @@ init_casts(void)
* 2. Addition, which needs to use the common instance, and runs into
* cast safety subtleties since we will implement it without an additional
* cast.
- *
- * NOTE: When first writing this, promotion did not exist for new-style loops,
- * if it exists, we could use promotion to implement double * sfloat.
*/
static int
multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context),
@@ -591,7 +588,8 @@ add_sfloats_resolve_descriptors(
static int
-add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
+add_loop(const char *ufunc_name,
+ PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter)
{
PyObject *mod = PyImport_ImportModule("numpy");
if (mod == NULL) {
@@ -605,13 +603,12 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
"numpy.%s was not a ufunc!", ufunc_name);
return -1;
}
- PyObject *dtype_tup = PyArray_TupleFromItems(
- 3, (PyObject **)bmeth->dtypes, 0);
+ PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1);
if (dtype_tup == NULL) {
Py_DECREF(ufunc);
return -1;
}
- PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method);
+ PyObject *info = PyTuple_Pack(2, dtype_tup, meth_or_promoter);
Py_DECREF(dtype_tup);
if (info == NULL) {
Py_DECREF(ufunc);
@@ -624,6 +621,28 @@ add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
}
+
+/*
+ * We add some very basic promoters to allow multiplying normal and scaled
+ */
+static int
+promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc),
+ PyArray_DTypeMeta *const NPY_UNUSED(dtypes[3]),
+ PyArray_DTypeMeta *const signature[3],
+ PyArray_DTypeMeta *new_dtypes[3])
+{
+ for (int i = 0; i < 3; i++) {
+ PyArray_DTypeMeta *new = &PyArray_SFloatDType;
+ if (signature[i] != NULL) {
+ new = signature[i];
+ }
+ Py_INCREF(new);
+ new_dtypes[i] = new;
+ }
+ return 0;
+}
+
+
/*
* Add new ufunc loops (this is somewhat clumsy as of writing it, but should
* get less so with the introduction of public API).
@@ -650,7 +669,8 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- int res = add_loop("multiply", bmeth);
+ int res = add_loop("multiply",
+ bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
return -1;
@@ -667,11 +687,40 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- res = add_loop("add", bmeth);
+ res = add_loop("add",
+ bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
return -1;
}
+
+ /*
+ * Add a promoter for both directions of multiply with double.
+ */
+ PyArray_DTypeMeta *double_DType = PyArray_DTypeFromTypeNum(NPY_DOUBLE);
+ Py_DECREF(double_DType); /* immortal anyway */
+
+ PyArray_DTypeMeta *promoter_dtypes[3] = {
+ &PyArray_SFloatDType, double_DType, NULL};
+
+ PyObject *promoter = PyCapsule_New(
+ &promote_to_sfloat, "numpy._ufunc_promoter", NULL);
+ if (promoter == NULL) {
+ return -1;
+ }
+ res = add_loop("multiply", promoter_dtypes, promoter);
+ if (res < 0) {
+ Py_DECREF(promoter);
+ return -1;
+ }
+ promoter_dtypes[0] = double_DType;
+ promoter_dtypes[1] = &PyArray_SFloatDType;
+ res = add_loop("multiply", promoter_dtypes, promoter);
+ Py_DECREF(promoter);
+ if (res < 0) {
+ return -1;
+ }
+
return 0;
}
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index b1c5ccb6b..b97441b13 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -97,8 +97,9 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
return -1;
}
}
- if (!PyObject_TypeCheck(PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
- /* Must also accept promoters in the future. */
+ PyObject *meth_or_promoter = PyTuple_GET_ITEM(info, 1);
+ if (!PyObject_TypeCheck(meth_or_promoter, &PyArrayMethod_Type)
+ && !PyCapsule_IsValid(meth_or_promoter, "numpy._ufunc_promoter")) {
PyErr_SetString(PyExc_TypeError,
"Second argument to info must be an ArrayMethod or promoter");
return -1;
@@ -354,15 +355,68 @@ resolve_implementation_info(PyUFuncObject *ufunc,
* those defined by the `signature` unmodified).
*/
static PyObject *
-call_promoter_and_recurse(
- PyUFuncObject *NPY_UNUSED(ufunc), PyObject *NPY_UNUSED(promoter),
- PyArray_DTypeMeta *NPY_UNUSED(op_dtypes[]),
- PyArray_DTypeMeta *NPY_UNUSED(signature[]),
- PyArrayObject *const NPY_UNUSED(operands[]))
+call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
+ PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
+ PyArrayObject *const operands[])
{
- PyErr_SetString(PyExc_NotImplementedError,
- "Internal NumPy error, promoters are not used/implemented yet.");
- return NULL;
+ int nargs = ufunc->nargs;
+ PyObject *resolved_info = NULL;
+
+ int promoter_result;
+ PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS];
+
+ if (PyCapsule_CheckExact(promoter)) {
+ /* We could also go the other way and wrap up the python function... */
+ promoter_function *promoter_function = PyCapsule_GetPointer(promoter,
+ "numpy._ufunc_promoter");
+ if (promoter_function == NULL) {
+ return NULL;
+ }
+ promoter_result = promoter_function(ufunc,
+ op_dtypes, signature, new_op_dtypes);
+ }
+ else {
+ PyErr_SetString(PyExc_NotImplementedError,
+ "Calling python functions for promotion is not implemented.");
+ return NULL;
+ }
+ if (promoter_result < 0) {
+ return NULL;
+ }
+ /*
+ * If none of the dtypes changes, we would recurse infinitely, abort.
+ * (Of course it is nevertheless possible to recurse infinitely.)
+ */
+ int dtypes_changed = 0;
+ for (int i = 0; i < nargs; i++) {
+ if (new_op_dtypes[i] != op_dtypes[i]) {
+ dtypes_changed = 1;
+ break;
+ }
+ }
+ if (!dtypes_changed) {
+ goto finish;
+ }
+
+ /*
+ * Do a recursive call, the promotion function has to ensure that the
+ * new tuple is strictly more precise (thus guaranteeing eventual finishing)
+ */
+ if (Py_EnterRecursiveCall(" during ufunc promotion.") != 0) {
+ goto finish;
+ }
+ /* TODO: The caching logic here may need revising: */
+ resolved_info = promote_and_get_info_and_ufuncimpl(ufunc,
+ operands, signature, new_op_dtypes,
+ /* no legacy promotion */ NPY_FALSE, /* cache */ NPY_TRUE);
+
+ Py_LeaveRecursiveCall();
+
+ finish:
+ for (int i = 0; i < nargs; i++) {
+ Py_XDECREF(new_op_dtypes[i]);
+ }
+ return resolved_info;
}
diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h
index b01bc79fa..8d116873c 100644
--- a/numpy/core/src/umath/dispatching.h
+++ b/numpy/core/src/umath/dispatching.h
@@ -7,6 +7,10 @@
#include "array_method.h"
+typedef int promoter_function(PyUFuncObject *ufunc,
+ PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
+ PyArray_DTypeMeta *new_op_dtypes[]);
+
NPY_NO_EXPORT int
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);
diff --git a/numpy/core/tests/test_custom_dtypes.py b/numpy/core/tests/test_custom_dtypes.py
index 3ec2363b9..5eb82bc93 100644
--- a/numpy/core/tests/test_custom_dtypes.py
+++ b/numpy/core/tests/test_custom_dtypes.py
@@ -101,6 +101,18 @@ class TestSFloat:
expected_view = a.view(np.float64) * b.view(np.float64)
assert_array_equal(res.view(np.float64), expected_view)
+ def test_basic_multiply_promotion(self):
+ float_a = np.array([1., 2., 3.])
+ b = self._get_array(2.)
+
+ res1 = float_a * b
+ res2 = b * float_a
+ # one factor is one, so we get the factor of b:
+ assert res1.dtype == res2.dtype == b.dtype
+ expected_view = float_a * b.view(np.float64)
+ assert_array_equal(res1.view(np.float64), expected_view)
+ assert_array_equal(res2.view(np.float64), expected_view)
+
def test_basic_addition(self):
a = self._get_array(2.)
b = self._get_array(4.)