diff options
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 9fdc67ee5..9fbffb3c3 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -6101,7 +6101,8 @@ PyArray_ValidType(int type) /* If the output is not a CARRAY, then it is buffered also */ static int -_bufferedcast(PyArrayObject *out, PyArrayObject *in) +_bufferedcast(PyArrayObject *out, PyArrayObject *in, + PyArray_VectorUnaryFunc *castfunc) { char *inbuffer, *bptr, *optr; char *outbuffer=NULL; @@ -6114,12 +6115,10 @@ _bufferedcast(PyArrayObject *out, PyArrayObject *in) int inswap, outswap=0; int obuf=!PyArray_ISCARRAY(out); int oelsize = out->descr->elsize; - PyArray_VectorUnaryFunc *castfunc; PyArray_CopySwapFunc *in_csn; PyArray_CopySwapFunc *out_csn; int retval = -1; - castfunc = in->descr->f->cast[out->descr->type_num]; in_csn = in->descr->f->copyswap; out_csn = out->descr->f->copyswap; @@ -6252,6 +6251,39 @@ PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran) } +/*OBJECT_API + Get a cast function to cast from the input descriptor to the + output type_number (must be a registered data-type). + Returns NULL if un-successful. +*/ +static PyArray_VectorUnaryFunc * +PyArray_GetCastFunc(PyArray_Descr *descr, int type_num) +{ + PyArray_VectorUnaryFunc *castfunc; + if (type_num >= PyArray_NTYPES) { + PyObject *obj = descr->f->castdict; + if (obj && PyDict_Check(obj)) { + PyObject *key; + PyObject *cobj; + key = PyInt_FromLong(type_num); + cobj = PyDict_GetItem(obj, key); + Py_DECREF(key); + if (PyCObject_Check(cobj)) { + castfunc = PyCObject_AsVoidPtr(cobj); + } + } + } + else { + castfunc = descr->f->cast[type_num]; + } + + if (castfunc) return castfunc; + + PyErr_SetString(PyExc_ValueError, + "No cast function available."); + return NULL; +} + /* The number of elements in out must be an integer multiple of the number of elements in mp. */ @@ -6266,6 +6298,7 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) int simple; intp mpsize = PyArray_SIZE(mp); intp outsize = PyArray_SIZE(out); + PyArray_VectorUnaryFunc *castfunc=NULL; if (mpsize == 0) return 0; if (!PyArray_ISWRITEABLE(out)) { @@ -6281,12 +6314,9 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) return -1; } - if (out->descr->type_num >= PyArray_NTYPES) { - PyErr_SetString(PyExc_ValueError, - "Can only cast to builtin types."); - return -1; - - } + castfunc = PyArray_GetCastFunc(mp->descr, out->descr->type_num); + + if (castfunc == NULL) return -1; simple = ((PyArray_ISCARRAY_RO(mp) && PyArray_ISCARRAY(out)) || \ (PyArray_ISFARRAY_RO(mp) && PyArray_ISFARRAY(out))); @@ -6299,17 +6329,14 @@ PyArray_CastTo(PyArrayObject *out, PyArrayObject *mp) while(ncopies--) { inptr = mp->data; - mp->descr->f->cast[out->descr->type_num](inptr, - optr, - mpsize, - mp, out); + castfunc(inptr, optr, mpsize, mp, out); optr += obytes; } return 0; } /* If not a well-behaved cast, then use buffers */ - if (_bufferedcast(out, mp) == -1) { + if (_bufferedcast(out, mp, castfunc) == -1) { return -1; } return 0; |