diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/code_generators/array_api_order.txt | 1 | ||||
-rw-r--r-- | numpy/core/include/numpy/arrayobject.h | 6 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 55 | ||||
-rw-r--r-- | numpy/core/src/arraytypes.inc.src | 6 | ||||
-rw-r--r-- | numpy/core/src/ufuncobject.c | 7 |
5 files changed, 56 insertions, 19 deletions
diff --git a/numpy/core/code_generators/array_api_order.txt b/numpy/core/code_generators/array_api_order.txt index 2491b5ab8..d088b6022 100644 --- a/numpy/core/code_generators/array_api_order.txt +++ b/numpy/core/code_generators/array_api_order.txt @@ -26,6 +26,7 @@ PyArray_CastScalarToCtype PyArray_CastScalarDirect PyArray_ScalarFromObject PyArray_RegisterDataType +PyArray_GetCastFunc PyArray_FromDims PyArray_FromDimsAndDataAndDescr PyArray_FromAny diff --git a/numpy/core/include/numpy/arrayobject.h b/numpy/core/include/numpy/arrayobject.h index 12a78918d..d337906ca 100644 --- a/numpy/core/include/numpy/arrayobject.h +++ b/numpy/core/include/numpy/arrayobject.h @@ -877,6 +877,12 @@ typedef struct { PyArray_SortFunc *sort[PyArray_NSORTS]; PyArray_ArgSortFunc *argsort[PyArray_NSORTS]; + /* Dictionary of additional casting functions + PyArray_VectorUnaryFuncs + which can be populated to support casting + to other registered types */ + PyObject *castdict; + } PyArray_ArrFuncs; diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 9fdc67ee5..9fbffb3c3 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -6101,7 +6101,8 @@ PyArray_ValidType(int type) /* If the output is not a CARRAY, then it is buffered also */ static int -_bufferedcast(PyArrayObject *out, PyArrayObject *in) +_bufferedcast(PyArrayObject *out, PyArrayObject *in, + PyArray_VectorUnaryFunc *castfunc) { char *inbuffer, *bptr, *optr; char *outbuffer=NULL; @@ -6114,12 +6115,10 @@ _bufferedcast(PyArrayObject *out, PyArrayObject *in) int inswap, outswap=0; int obuf=!PyArray_ISCARRAY(out); int oelsize = out->descr->elsize; - PyArray_VectorUnaryFunc *castfunc; PyArray_CopySwapFunc *in_csn; PyArray_CopySwapFunc *out_csn; int retval = -1; - castfunc = in->descr->f->cast[out->descr->type_num]; in_csn = in->descr->f->copyswap; out_csn = out->descr->f->copyswap; @@ -6252,6 +6251,39 @@ PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran) } +/*OBJECT_API + Get a cast function to cast from the input descriptor to the + output type_number (must be a registered data-type). + Returns NULL if un-successful. +*/ +static PyArray_VectorUnaryFunc * +PyArray_GetCastFunc(PyArray_Descr *descr, int type_num) +{ + PyArray_VectorUnaryFunc *castfunc; + if (type_num >= PyArray_NTYPES) { + PyObject *obj = descr->f->castdict; + if (obj && PyDict_Check(obj)) { + PyObject *key; + PyObject *cobj; + key = PyInt_FromLong(type_num); + cobj = PyDict_GetItem(obj, key); + Py_DECREF(key); + if (PyCObject_Check(cobj)) { + castfunc = PyCObject_AsVoidPtr(cobj); + } + } + } + else { + castfunc = descr->f->cast[type_num]; + } + + if (castfunc) return castfunc; + + PyErr_SetString(PyExc_ValueError, + "No cast function available."); + return NULL; +} + /* The number of elements in out must be an integer multiple of the number of elements in mp. */ @@ -6266,6 +6298,7 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) int simple; intp mpsize = PyArray_SIZE(mp); intp outsize = PyArray_SIZE(out); + PyArray_VectorUnaryFunc *castfunc=NULL; if (mpsize == 0) return 0; if (!PyArray_ISWRITEABLE(out)) { @@ -6281,12 +6314,9 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) return -1; } - if (out->descr->type_num >= PyArray_NTYPES) { - PyErr_SetString(PyExc_ValueError, - "Can only cast to builtin types."); - return -1; - - } + castfunc = PyArray_GetCastFunc(mp->descr, out->descr->type_num); + + if (castfunc == NULL) return -1; simple = ((PyArray_ISCARRAY_RO(mp) && PyArray_ISCARRAY(out)) || \ (PyArray_ISFARRAY_RO(mp) && PyArray_ISFARRAY(out))); @@ -6299,17 +6329,14 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) while(ncopies--) { inptr = mp->data; - mp->descr->f->cast[out->descr->type_num](inptr, - optr, - mpsize, - mp, out); + castfunc(inptr, optr, mpsize, mp, out); optr += obytes; } return 0; } /* If not a well-behaved cast, then use buffers */ - if (_bufferedcast(out, mp) == -1) { + if (_bufferedcast(out, mp, castfunc) == -1) { return -1; } return 0; diff --git a/numpy/core/src/arraytypes.inc.src b/numpy/core/src/arraytypes.inc.src index b4c02d6a1..f3819e0b6 100644 --- a/numpy/core/src/arraytypes.inc.src +++ b/numpy/core/src/arraytypes.inc.src @@ -1859,7 +1859,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = { }, { NULL, NULL, NULL, NULL - } + }, + NULL }; static PyArray_Descr @from@_Descr = { @@ -1930,7 +1931,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = { }, { NULL, NULL, NULL, NULL - } + }, + NULL }; static PyArray_Descr @from@_Descr = { diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index 3b10676c7..25d84d8fd 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -1209,7 +1209,7 @@ construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps) else scntcast += descr->elsize; if (i < self->nin) { - loop->cast[i] = \ + loop->cast[i] = \ mps[i]->descr->f->cast[arg_types[i]]; } else { @@ -1935,8 +1935,9 @@ construct_reduce(PyUFuncObject *self, PyArrayObject **arr, int axis, if (loop->obj) memset(loop->buffer, 0, _size); loop->castbuf = loop->buffer + \ loop->bufsize*aar->descr->elsize; - loop->bufptr[0] = loop->castbuf; - loop->cast = aar->descr->f->cast[otype]; + loop->bufptr[0] = loop->castbuf; + loop->cast = PyArray_GetCastFunc(aar->descr, otype); + if (loop->cast == NULL) goto fail; } else { _size = loop->bufsize * loop->outsize; |