summaryrefslogtreecommitdiff
path: root/numpy/core/src/umath/_scaled_float_dtype.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/umath/_scaled_float_dtype.c')
-rw-r--r--numpy/core/src/umath/_scaled_float_dtype.c128
1 files changed, 106 insertions, 22 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c
index a214b32aa..c26ace9f1 100644
--- a/numpy/core/src/umath/_scaled_float_dtype.c
+++ b/numpy/core/src/umath/_scaled_float_dtype.c
@@ -26,6 +26,14 @@
#include "dispatching.h"
+/* TODO: from wrapping_array_method.c, use proper public header eventually */
+NPY_NO_EXPORT int
+PyUFunc_AddWrappingLoop(PyObject *ufunc_obj,
+ PyArray_DTypeMeta *new_dtypes[], PyArray_DTypeMeta *wrapped_dtypes[],
+ translate_given_descrs_func *translate_given_descrs,
+ translate_loop_descrs_func *translate_loop_descrs);
+
+
typedef struct {
PyArray_Descr base;
double scaling;
@@ -125,9 +133,9 @@ sfloat_setitem(PyObject *obj, char *data, PyArrayObject *arr)
/* Special DType methods and the descr->f slot storage */
NPY_DType_Slots sfloat_slots = {
- .default_descr = &sfloat_default_descr,
.discover_descr_from_pyobject = &sfloat_discover_from_pyobject,
.is_known_scalar_type = &sfloat_is_known_scalar_type,
+ .default_descr = &sfloat_default_descr,
.common_dtype = &sfloat_common_dtype,
.common_instance = &sfloat_common_instance,
.f = {
@@ -136,14 +144,13 @@ NPY_DType_Slots sfloat_slots = {
}
};
-
static PyArray_SFloatDescr SFloatSingleton = {{
- .elsize = sizeof(double),
- .alignment = _ALIGN(double),
+ .byteorder = '|', /* do not bother with byte-swapping... */
.flags = NPY_USE_GETITEM|NPY_USE_SETITEM,
.type_num = -1,
+ .elsize = sizeof(double),
+ .alignment = _ALIGN(double),
.f = &sfloat_slots.f,
- .byteorder = '|', /* do not bother with byte-swapping... */
},
.scaling = 1,
};
@@ -233,15 +240,15 @@ sfloat_repr(PyArray_SFloatDescr *self)
static PyArray_DTypeMeta PyArray_SFloatDType = {{{
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "numpy._ScaledFloatTestDType",
- .tp_methods = sfloat_methods,
- .tp_new = sfloat_new,
+ .tp_basicsize = sizeof(PyArray_SFloatDescr),
.tp_repr = (reprfunc)sfloat_repr,
.tp_str = (reprfunc)sfloat_repr,
- .tp_basicsize = sizeof(PyArray_SFloatDescr),
+ .tp_methods = sfloat_methods,
+ .tp_new = sfloat_new,
}},
.type_num = -1,
.scalar_type = NULL,
- .flags = NPY_DT_PARAMETRIC,
+ .flags = NPY_DT_PARAMETRIC | NPY_DT_NUMERIC,
.dt_slots = &sfloat_slots,
};
@@ -440,7 +447,7 @@ sfloat_to_bool_resolve_descriptors(
static int
-init_casts(void)
+sfloat_init_casts(void)
{
PyArray_DTypeMeta *dtypes[2] = {&PyArray_SFloatDType, &PyArray_SFloatDType};
PyType_Slot slots[4] = {{0, NULL}};
@@ -448,11 +455,11 @@ init_casts(void)
.name = "sfloat_to_sfloat_cast",
.nin = 1,
.nout = 1,
+ /* minimal guaranteed casting */
+ .casting = NPY_SAME_KIND_CASTING,
.flags = NPY_METH_SUPPORTS_UNALIGNED,
.dtypes = dtypes,
.slots = slots,
- /* minimal guaranteed casting */
- .casting = NPY_SAME_KIND_CASTING,
};
slots[0].slot = NPY_METH_resolve_descriptors;
@@ -646,13 +653,55 @@ add_sfloats_resolve_descriptors(
}
+/*
+ * We define the hypot loop using the "PyUFunc_AddWrappingLoop" API.
+ * We use this very narrowly for mapping to the double hypot loop currently.
+ */
static int
-add_loop(const char *ufunc_name,
- PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter)
+translate_given_descrs_to_double(
+ int nin, int nout, PyArray_DTypeMeta *wrapped_dtypes[],
+ PyArray_Descr *given_descrs[], PyArray_Descr *new_descrs[])
+{
+ assert(nin == 2 && nout == 1);
+ for (int i = 0; i < 3; i++) {
+ if (given_descrs[i] == NULL) {
+ new_descrs[i] = NULL;
+ }
+ else {
+ new_descrs[i] = PyArray_DescrFromType(NPY_DOUBLE);
+ }
+ }
+ return 0;
+}
+
+
+static int
+translate_loop_descrs(
+ int nin, int nout, PyArray_DTypeMeta *new_dtypes[],
+ PyArray_Descr *given_descrs[],
+ PyArray_Descr *NPY_UNUSED(original_descrs[]),
+ PyArray_Descr *loop_descrs[])
+{
+ assert(nin == 2 && nout == 1);
+ loop_descrs[0] = sfloat_common_instance(
+ given_descrs[0], given_descrs[1]);
+ if (loop_descrs[0] == 0) {
+ return -1;
+ }
+ Py_INCREF(loop_descrs[0]);
+ loop_descrs[1] = loop_descrs[0];
+ Py_INCREF(loop_descrs[0]);
+ loop_descrs[2] = loop_descrs[0];
+ return 0;
+}
+
+
+static PyObject *
+sfloat_get_ufunc(const char *ufunc_name)
{
PyObject *mod = PyImport_ImportModule("numpy");
if (mod == NULL) {
- return -1;
+ return NULL;
}
PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name);
Py_DECREF(mod);
@@ -660,6 +709,18 @@ add_loop(const char *ufunc_name,
Py_DECREF(ufunc);
PyErr_Format(PyExc_TypeError,
"numpy.%s was not a ufunc!", ufunc_name);
+ return NULL;
+ }
+ return ufunc;
+}
+
+
+static int
+sfloat_add_loop(const char *ufunc_name,
+ PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter)
+{
+ PyObject *ufunc = sfloat_get_ufunc(ufunc_name);
+ if (ufunc == NULL) {
return -1;
}
PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1);
@@ -680,6 +741,24 @@ add_loop(const char *ufunc_name,
}
+static int
+sfloat_add_wrapping_loop(const char *ufunc_name, PyArray_DTypeMeta *dtypes[3])
+{
+ PyObject *ufunc = sfloat_get_ufunc(ufunc_name);
+ if (ufunc == NULL) {
+ return -1;
+ }
+ PyArray_DTypeMeta *double_dt = PyArray_DTypeFromTypeNum(NPY_DOUBLE);
+ PyArray_DTypeMeta *wrapped_dtypes[3] = {double_dt, double_dt, double_dt};
+ int res = PyUFunc_AddWrappingLoop(
+ ufunc, dtypes, wrapped_dtypes, &translate_given_descrs_to_double,
+ &translate_loop_descrs);
+ Py_DECREF(ufunc);
+ Py_DECREF(double_dt);
+
+ return res;
+}
+
/*
* We add some very basic promoters to allow multiplying normal and scaled
@@ -707,7 +786,7 @@ promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc),
* get less so with the introduction of public API).
*/
static int
-init_ufuncs(void) {
+sfloat_init_ufuncs(void) {
PyArray_DTypeMeta *dtypes[3] = {
&PyArray_SFloatDType, &PyArray_SFloatDType, &PyArray_SFloatDType};
PyType_Slot slots[3] = {{0, NULL}};
@@ -728,7 +807,7 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- int res = add_loop("multiply",
+ int res = sfloat_add_loop("multiply",
bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
@@ -746,13 +825,18 @@ init_ufuncs(void) {
if (bmeth == NULL) {
return -1;
}
- res = add_loop("add",
+ res = sfloat_add_loop("add",
bmeth->dtypes, (PyObject *)bmeth->method);
Py_DECREF(bmeth);
if (res < 0) {
return -1;
}
+ /* N.B.: Wrapping isn't actually correct if scaling can be negative */
+ if (sfloat_add_wrapping_loop("hypot", dtypes) < 0) {
+ return -1;
+ }
+
/*
* Add a promoter for both directions of multiply with double.
*/
@@ -767,14 +851,14 @@ init_ufuncs(void) {
if (promoter == NULL) {
return -1;
}
- res = add_loop("multiply", promoter_dtypes, promoter);
+ res = sfloat_add_loop("multiply", promoter_dtypes, promoter);
if (res < 0) {
Py_DECREF(promoter);
return -1;
}
promoter_dtypes[0] = double_DType;
promoter_dtypes[1] = &PyArray_SFloatDType;
- res = add_loop("multiply", promoter_dtypes, promoter);
+ res = sfloat_add_loop("multiply", promoter_dtypes, promoter);
Py_DECREF(promoter);
if (res < 0) {
return -1;
@@ -815,11 +899,11 @@ get_sfloat_dtype(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(args))
return NULL;
}
- if (init_casts() < 0) {
+ if (sfloat_init_casts() < 0) {
return NULL;
}
- if (init_ufuncs() < 0) {
+ if (sfloat_init_ufuncs() < 0) {
return NULL;
}