summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c55
1 files changed, 41 insertions, 14 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 9fdc67ee5..9fbffb3c3 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -6101,7 +6101,8 @@ PyArray_ValidType(int type)
/* If the output is not a CARRAY, then it is buffered also */
static int
-_bufferedcast(PyArrayObject *out, PyArrayObject *in)
+_bufferedcast(PyArrayObject *out, PyArrayObject *in,
+ PyArray_VectorUnaryFunc *castfunc)
{
char *inbuffer, *bptr, *optr;
char *outbuffer=NULL;
@@ -6114,12 +6115,10 @@ _bufferedcast(PyArrayObject *out, PyArrayObject *in)
int inswap, outswap=0;
int obuf=!PyArray_ISCARRAY(out);
int oelsize = out->descr->elsize;
- PyArray_VectorUnaryFunc *castfunc;
PyArray_CopySwapFunc *in_csn;
PyArray_CopySwapFunc *out_csn;
int retval = -1;
- castfunc = in->descr->f->cast[out->descr->type_num];
in_csn = in->descr->f->copyswap;
out_csn = out->descr->f->copyswap;
@@ -6252,6 +6251,39 @@ PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran)
}
+/*OBJECT_API
+ Get a cast function to cast from the input descriptor to the
+ output type_number (must be a registered data-type).
+ Returns NULL if un-successful.
+*/
+static PyArray_VectorUnaryFunc *
+PyArray_GetCastFunc(PyArray_Descr *descr, int type_num)
+{
+ PyArray_VectorUnaryFunc *castfunc;
+ if (type_num >= PyArray_NTYPES) {
+ PyObject *obj = descr->f->castdict;
+ if (obj && PyDict_Check(obj)) {
+ PyObject *key;
+ PyObject *cobj;
+ key = PyInt_FromLong(type_num);
+ cobj = PyDict_GetItem(obj, key);
+ Py_DECREF(key);
+ if (PyCObject_Check(cobj)) {
+ castfunc = PyCObject_AsVoidPtr(cobj);
+ }
+ }
+ }
+ else {
+ castfunc = descr->f->cast[type_num];
+ }
+
+ if (castfunc) return castfunc;
+
+ PyErr_SetString(PyExc_ValueError,
+ "No cast function available.");
+ return NULL;
+}
+
/* The number of elements in out must be an integer multiple
of the number of elements in mp.
*/
@@ -6266,6 +6298,7 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
int simple;
intp mpsize = PyArray_SIZE(mp);
intp outsize = PyArray_SIZE(out);
+ PyArray_VectorUnaryFunc *castfunc=NULL;
if (mpsize == 0) return 0;
if (!PyArray_ISWRITEABLE(out)) {
@@ -6281,12 +6314,9 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
return -1;
}
- if (out->descr->type_num >= PyArray_NTYPES) {
- PyErr_SetString(PyExc_ValueError,
- "Can only cast to builtin types.");
- return -1;
-
- }
+ castfunc = PyArray_GetCastFunc(mp->descr, out->descr->type_num);
+
+ if (castfunc == NULL) return -1;
simple = ((PyArray_ISCARRAY_RO(mp) && PyArray_ISCARRAY(out)) || \
(PyArray_ISFARRAY_RO(mp) && PyArray_ISFARRAY(out)));
@@ -6299,17 +6329,14 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp)
while(ncopies--) {
inptr = mp->data;
- mp->descr->f->cast[out->descr->type_num](inptr,
- optr,
- mpsize,
- mp, out);
+ castfunc(inptr, optr, mpsize, mp, out);
optr += obytes;
}
return 0;
}
/* If not a well-behaved cast, then use buffers */
- if (_bufferedcast(out, mp) == -1) {
+ if (_bufferedcast(out, mp, castfunc) == -1) {
return -1;
}
return 0;