diff options
Diffstat (limited to 'numpy/core/src/umath/_scaled_float_dtype.c')
-rw-r--r-- | numpy/core/src/umath/_scaled_float_dtype.c | 128 |
1 files changed, 106 insertions, 22 deletions
diff --git a/numpy/core/src/umath/_scaled_float_dtype.c b/numpy/core/src/umath/_scaled_float_dtype.c index a214b32aa..c26ace9f1 100644 --- a/numpy/core/src/umath/_scaled_float_dtype.c +++ b/numpy/core/src/umath/_scaled_float_dtype.c @@ -26,6 +26,14 @@ #include "dispatching.h" +/* TODO: from wrapping_array_method.c, use proper public header eventually */ +NPY_NO_EXPORT int +PyUFunc_AddWrappingLoop(PyObject *ufunc_obj, + PyArray_DTypeMeta *new_dtypes[], PyArray_DTypeMeta *wrapped_dtypes[], + translate_given_descrs_func *translate_given_descrs, + translate_loop_descrs_func *translate_loop_descrs); + + typedef struct { PyArray_Descr base; double scaling; @@ -125,9 +133,9 @@ sfloat_setitem(PyObject *obj, char *data, PyArrayObject *arr) /* Special DType methods and the descr->f slot storage */ NPY_DType_Slots sfloat_slots = { - .default_descr = &sfloat_default_descr, .discover_descr_from_pyobject = &sfloat_discover_from_pyobject, .is_known_scalar_type = &sfloat_is_known_scalar_type, + .default_descr = &sfloat_default_descr, .common_dtype = &sfloat_common_dtype, .common_instance = &sfloat_common_instance, .f = { @@ -136,14 +144,13 @@ NPY_DType_Slots sfloat_slots = { } }; - static PyArray_SFloatDescr SFloatSingleton = {{ - .elsize = sizeof(double), - .alignment = _ALIGN(double), + .byteorder = '|', /* do not bother with byte-swapping... */ .flags = NPY_USE_GETITEM|NPY_USE_SETITEM, .type_num = -1, + .elsize = sizeof(double), + .alignment = _ALIGN(double), .f = &sfloat_slots.f, - .byteorder = '|', /* do not bother with byte-swapping... */ }, .scaling = 1, }; @@ -233,15 +240,15 @@ sfloat_repr(PyArray_SFloatDescr *self) static PyArray_DTypeMeta PyArray_SFloatDType = {{{ PyVarObject_HEAD_INIT(NULL, 0) .tp_name = "numpy._ScaledFloatTestDType", - .tp_methods = sfloat_methods, - .tp_new = sfloat_new, + .tp_basicsize = sizeof(PyArray_SFloatDescr), .tp_repr = (reprfunc)sfloat_repr, .tp_str = (reprfunc)sfloat_repr, - .tp_basicsize = sizeof(PyArray_SFloatDescr), + .tp_methods = sfloat_methods, + .tp_new = sfloat_new, }}, .type_num = -1, .scalar_type = NULL, - .flags = NPY_DT_PARAMETRIC, + .flags = NPY_DT_PARAMETRIC | NPY_DT_NUMERIC, .dt_slots = &sfloat_slots, }; @@ -440,7 +447,7 @@ sfloat_to_bool_resolve_descriptors( static int -init_casts(void) +sfloat_init_casts(void) { PyArray_DTypeMeta *dtypes[2] = {&PyArray_SFloatDType, &PyArray_SFloatDType}; PyType_Slot slots[4] = {{0, NULL}}; @@ -448,11 +455,11 @@ init_casts(void) .name = "sfloat_to_sfloat_cast", .nin = 1, .nout = 1, + /* minimal guaranteed casting */ + .casting = NPY_SAME_KIND_CASTING, .flags = NPY_METH_SUPPORTS_UNALIGNED, .dtypes = dtypes, .slots = slots, - /* minimal guaranteed casting */ - .casting = NPY_SAME_KIND_CASTING, }; slots[0].slot = NPY_METH_resolve_descriptors; @@ -646,13 +653,55 @@ add_sfloats_resolve_descriptors( } +/* + * We define the hypot loop using the "PyUFunc_AddWrappingLoop" API. + * We use this very narrowly for mapping to the double hypot loop currently. + */ static int -add_loop(const char *ufunc_name, - PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter) +translate_given_descrs_to_double( + int nin, int nout, PyArray_DTypeMeta *wrapped_dtypes[], + PyArray_Descr *given_descrs[], PyArray_Descr *new_descrs[]) +{ + assert(nin == 2 && nout == 1); + for (int i = 0; i < 3; i++) { + if (given_descrs[i] == NULL) { + new_descrs[i] = NULL; + } + else { + new_descrs[i] = PyArray_DescrFromType(NPY_DOUBLE); + } + } + return 0; +} + + +static int +translate_loop_descrs( + int nin, int nout, PyArray_DTypeMeta *new_dtypes[], + PyArray_Descr *given_descrs[], + PyArray_Descr *NPY_UNUSED(original_descrs[]), + PyArray_Descr *loop_descrs[]) +{ + assert(nin == 2 && nout == 1); + loop_descrs[0] = sfloat_common_instance( + given_descrs[0], given_descrs[1]); + if (loop_descrs[0] == 0) { + return -1; + } + Py_INCREF(loop_descrs[0]); + loop_descrs[1] = loop_descrs[0]; + Py_INCREF(loop_descrs[0]); + loop_descrs[2] = loop_descrs[0]; + return 0; +} + + +static PyObject * +sfloat_get_ufunc(const char *ufunc_name) { PyObject *mod = PyImport_ImportModule("numpy"); if (mod == NULL) { - return -1; + return NULL; } PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name); Py_DECREF(mod); @@ -660,6 +709,18 @@ add_loop(const char *ufunc_name, Py_DECREF(ufunc); PyErr_Format(PyExc_TypeError, "numpy.%s was not a ufunc!", ufunc_name); + return NULL; + } + return ufunc; +} + + +static int +sfloat_add_loop(const char *ufunc_name, + PyArray_DTypeMeta *dtypes[3], PyObject *meth_or_promoter) +{ + PyObject *ufunc = sfloat_get_ufunc(ufunc_name); + if (ufunc == NULL) { return -1; } PyObject *dtype_tup = PyArray_TupleFromItems(3, (PyObject **)dtypes, 1); @@ -680,6 +741,24 @@ add_loop(const char *ufunc_name, } +static int +sfloat_add_wrapping_loop(const char *ufunc_name, PyArray_DTypeMeta *dtypes[3]) +{ + PyObject *ufunc = sfloat_get_ufunc(ufunc_name); + if (ufunc == NULL) { + return -1; + } + PyArray_DTypeMeta *double_dt = PyArray_DTypeFromTypeNum(NPY_DOUBLE); + PyArray_DTypeMeta *wrapped_dtypes[3] = {double_dt, double_dt, double_dt}; + int res = PyUFunc_AddWrappingLoop( + ufunc, dtypes, wrapped_dtypes, &translate_given_descrs_to_double, + &translate_loop_descrs); + Py_DECREF(ufunc); + Py_DECREF(double_dt); + + return res; +} + /* * We add some very basic promoters to allow multiplying normal and scaled @@ -707,7 +786,7 @@ promote_to_sfloat(PyUFuncObject *NPY_UNUSED(ufunc), * get less so with the introduction of public API). */ static int -init_ufuncs(void) { +sfloat_init_ufuncs(void) { PyArray_DTypeMeta *dtypes[3] = { &PyArray_SFloatDType, &PyArray_SFloatDType, &PyArray_SFloatDType}; PyType_Slot slots[3] = {{0, NULL}}; @@ -728,7 +807,7 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - int res = add_loop("multiply", + int res = sfloat_add_loop("multiply", bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { @@ -746,13 +825,18 @@ init_ufuncs(void) { if (bmeth == NULL) { return -1; } - res = add_loop("add", + res = sfloat_add_loop("add", bmeth->dtypes, (PyObject *)bmeth->method); Py_DECREF(bmeth); if (res < 0) { return -1; } + /* N.B.: Wrapping isn't actually correct if scaling can be negative */ + if (sfloat_add_wrapping_loop("hypot", dtypes) < 0) { + return -1; + } + /* * Add a promoter for both directions of multiply with double. */ @@ -767,14 +851,14 @@ init_ufuncs(void) { if (promoter == NULL) { return -1; } - res = add_loop("multiply", promoter_dtypes, promoter); + res = sfloat_add_loop("multiply", promoter_dtypes, promoter); if (res < 0) { Py_DECREF(promoter); return -1; } promoter_dtypes[0] = double_DType; promoter_dtypes[1] = &PyArray_SFloatDType; - res = add_loop("multiply", promoter_dtypes, promoter); + res = sfloat_add_loop("multiply", promoter_dtypes, promoter); Py_DECREF(promoter); if (res < 0) { return -1; @@ -815,11 +899,11 @@ get_sfloat_dtype(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(args)) return NULL; } - if (init_casts() < 0) { + if (sfloat_init_casts() < 0) { return NULL; } - if (init_ufuncs() < 0) { + if (sfloat_init_ufuncs() < 0) { return NULL; } |