summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/code_generators/array_api_order.txt1
-rw-r--r--numpy/core/include/numpy/arrayobject.h6
-rw-r--r--numpy/core/src/arrayobject.c55
-rw-r--r--numpy/core/src/arraytypes.inc.src6
-rw-r--r--numpy/core/src/ufuncobject.c7
5 files changed, 56 insertions, 19 deletions
diff --git a/numpy/core/code_generators/array_api_order.txt b/numpy/core/code_generators/array_api_order.txt
index 2491b5ab8..d088b6022 100644
--- a/numpy/core/code_generators/array_api_order.txt
+++ b/numpy/core/code_generators/array_api_order.txt
@@ -26,6 +26,7 @@ PyArray_CastScalarToCtype
PyArray_CastScalarDirect
PyArray_ScalarFromObject
PyArray_RegisterDataType
+PyArray_GetCastFunc
PyArray_FromDims
PyArray_FromDimsAndDataAndDescr
PyArray_FromAny
diff --git a/numpy/core/include/numpy/arrayobject.h b/numpy/core/include/numpy/arrayobject.h
index 12a78918d..d337906ca 100644
--- a/numpy/core/include/numpy/arrayobject.h
+++ b/numpy/core/include/numpy/arrayobject.h
@@ -877,6 +877,12 @@ typedef struct {
PyArray_SortFunc *sort[PyArray_NSORTS];
PyArray_ArgSortFunc *argsort[PyArray_NSORTS];
+ /* Dictionary of additional casting functions
+ PyArray_VectorUnaryFuncs
+ which can be populated to support casting
+ to other registered types */
+ PyObject *castdict;
+
} PyArray_ArrFuncs;
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 9fdc67ee5..9fbffb3c3 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -6101,7 +6101,8 @@ PyArray_ValidType(int type)
/* If the output is not a CARRAY, then it is buffered also */
static int
-_bufferedcast(PyArrayObject *out, PyArrayObject *in)
+_bufferedcast(PyArrayObject *out, PyArrayObject *in,
+ PyArray_VectorUnaryFunc *castfunc)
{
char *inbuffer, *bptr, *optr;
char *outbuffer=NULL;
@@ -6114,12 +6115,10 @@ _bufferedcast(PyArrayObject *out, PyArrayObject *in)
int inswap, outswap=0;
int obuf=!PyArray_ISCARRAY(out);
int oelsize = out->descr->elsize;
- PyArray_VectorUnaryFunc *castfunc;
PyArray_CopySwapFunc *in_csn;
PyArray_CopySwapFunc *out_csn;
int retval = -1;
- castfunc = in->descr->f->cast[out->descr->type_num];
in_csn = in->descr->f->copyswap;
out_csn = out->descr->f->copyswap;
@@ -6252,6 +6251,39 @@ PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran)
}
+/*OBJECT_API
+ Get a cast function to cast from the input descriptor to the
+ output type_number (must be a registered data-type).
+ Returns NULL if un-successful.
+*/
+static PyArray_VectorUnaryFunc *
+PyArray_GetCastFunc(PyArray_Descr *descr, int type_num)
+{
+ PyArray_VectorUnaryFunc *castfunc;
+ if (type_num >= PyArray_NTYPES) {
+ PyObject *obj = descr->f->castdict;
+ if (obj && PyDict_Check(obj)) {
+ PyObject *key;
+ PyObject *cobj;
+ key = PyInt_FromLong(type_num);
+ cobj = PyDict_GetItem(obj, key);
+ Py_DECREF(key);
+ if (PyCObject_Check(cobj)) {
+ castfunc = PyCObject_AsVoidPtr(cobj);
+ }
+ }
+ }
+ else {
+ castfunc = descr->f->cast[type_num];
+ }
+
+ if (castfunc) return castfunc;
+
+ PyErr_SetString(PyExc_ValueError,
+ "No cast function available.");
+ return NULL;
+}
+
/* The number of elements in out must be an integer multiple
of the number of elements in mp.
*/
@@ -6266,6 +6298,7 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
int simple;
intp mpsize = PyArray_SIZE(mp);
intp outsize = PyArray_SIZE(out);
+ PyArray_VectorUnaryFunc *castfunc=NULL;
if (mpsize == 0) return 0;
if (!PyArray_ISWRITEABLE(out)) {
@@ -6281,12 +6314,9 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
return -1;
}
- if (out->descr->type_num >= PyArray_NTYPES) {
- PyErr_SetString(PyExc_ValueError,
- "Can only cast to builtin types.");
- return -1;
-
- }
+ castfunc = PyArray_GetCastFunc(mp->descr, out->descr->type_num);
+
+ if (castfunc == NULL) return -1;
simple = ((PyArray_ISCARRAY_RO(mp) && PyArray_ISCARRAY(out)) || \
(PyArray_ISFARRAY_RO(mp) && PyArray_ISFARRAY(out)));
@@ -6299,17 +6329,14 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
while(ncopies--) {
inptr = mp->data;
- mp->descr->f->cast[out->descr->type_num](inptr,
- optr,
- mpsize,
- mp, out);
+ castfunc(inptr, optr, mpsize, mp, out);
optr += obytes;
}
return 0;
}
/* If not a well-behaved cast, then use buffers */
- if (_bufferedcast(out, mp) == -1) {
+ if (_bufferedcast(out, mp, castfunc) == -1) {
return -1;
}
return 0;
diff --git a/numpy/core/src/arraytypes.inc.src b/numpy/core/src/arraytypes.inc.src
index b4c02d6a1..f3819e0b6 100644
--- a/numpy/core/src/arraytypes.inc.src
+++ b/numpy/core/src/arraytypes.inc.src
@@ -1859,7 +1859,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
},
{
NULL, NULL, NULL, NULL
- }
+ },
+ NULL
};
static PyArray_Descr @from@_Descr = {
@@ -1930,7 +1931,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
},
{
NULL, NULL, NULL, NULL
- }
+ },
+ NULL
};
static PyArray_Descr @from@_Descr = {
diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c
index 3b10676c7..25d84d8fd 100644
--- a/numpy/core/src/ufuncobject.c
+++ b/numpy/core/src/ufuncobject.c
@@ -1209,7 +1209,7 @@ construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps)
else
scntcast += descr->elsize;
if (i < self->nin) {
- loop->cast[i] = \
+ loop->cast[i] = \
mps[i]->descr->f->cast[arg_types[i]];
}
else {
@@ -1935,8 +1935,9 @@ construct_reduce(PyUFuncObject *self, PyArrayObject **arr, int axis,
if (loop->obj) memset(loop->buffer, 0, _size);
loop->castbuf = loop->buffer + \
loop->bufsize*aar->descr->elsize;
- loop->bufptr[0] = loop->castbuf;
- loop->cast = aar->descr->f->cast[otype];
+ loop->bufptr[0] = loop->castbuf;
+ loop->cast = PyArray_GetCastFunc(aar->descr, otype);
+ if (loop->cast == NULL) goto fail;
}
else {
_size = loop->bufsize * loop->outsize;