diff options
Diffstat (limited to 'numpy/core/src/ufuncobject.c')
-rw-r--r-- | numpy/core/src/ufuncobject.c | 132 |
1 files changed, 84 insertions, 48 deletions
diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index 25d84d8fd..a4e284ccc 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -625,52 +625,66 @@ select_types(PyUFuncObject *self, int *arg_types, PyUFuncGenericFunction *function, void **data, PyArray_SCALARKIND *scalars) { - int i=0, j; char start_type; - - if (PyTypeNum_ISUSERDEF((arg_types[0]))) { - PyObject *key, *obj; + + if (self->userloops) { + int userdef=-1; for (i=0; i<self->nin; i++) { - if (arg_types[i] != arg_types[0]) { - PyErr_SetString(PyExc_TypeError, - "ufuncs on user defined" \ - " types don't support "\ - "coercion"); - return -1; + if (PyTypeNum_ISUSERDEF(arg_types[i])) { + userdef = arg_types[i]; + break; } } - for (i=self->nin; i<self->nargs; i++) { - arg_types[i] = arg_types[0]; - } - - obj = NULL; - if (self->userloops) { - key = PyInt_FromLong((long) arg_types[0]); + if (userdef > 0) { + PyObject *key, *obj; + int *this_types=NULL; + + obj = NULL; + key = PyInt_FromLong((long) userdef); if (key == NULL) return -1; obj = PyDict_GetItem(self->userloops, key); Py_DECREF(key); + if (obj == NULL) { + PyErr_SetString(PyExc_TypeError, + "user-defined type used in ufunc" \ + " with no registered loops"); + return -1; + } + if PyTuple_Check(obj) { + PyObject *item; + *function = (PyUFuncGenericFunction) \ + PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj, + 0)); + item = PyTuple_GET_ITEM(obj, 2); + if (PyCObject_Check(item)) { + *data = PyCObject_AsVoidPtr(item); + } + item = PyTuple_GET_ITEM(obj, 1); + if (PyCObject_Check(item)) { + this_types = PyCObject_AsVoidPtr(item); + } + } + else { + *function = (PyUFuncGenericFunction) \ + PyCObject_AsVoidPtr(obj); + *data = NULL; + } + + if (this_types == NULL) { + for (i=1; i<self->nargs; i++) { + arg_types[i] = userdef; + } + } + else { + for (i=1; i<self->nargs; i++) { + arg_types[i] = this_types[i]; + } + } + Py_DECREF(obj); + return 0; } - if (obj == NULL) { - PyErr_SetString(PyExc_TypeError, - "no registered loop for this " \ - "user-defined type"); - return -1; - } - if PyTuple_Check(obj) { - *function = (PyUFuncGenericFunction) \ - PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj, 0)); - *data = PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj, 1)); - } - else { - *function = (PyUFuncGenericFunction) \ - PyCObject_AsVoidPtr(obj); - *data = NULL; - } - Py_DECREF(obj); - return 0; } - start_type = arg_types[0]; /* If the first argument is a scalar we need to place @@ -1210,13 +1224,15 @@ construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps) scntcast += descr->elsize; if (i < self->nin) { loop->cast[i] = \ - mps[i]->descr->f->cast[arg_types[i]]; + PyArray_GetCastFunc(mps[i]->descr, + arg_types[i]); } else { - loop->cast[i] = descr->f-> \ - cast[mps[i]->descr->type_num]; + loop->cast[i] = PyArray_GetCastFunc \ + (descr, mps[i]->descr->type_num); } Py_DECREF(descr); + if (!loop->cast[i]) return -1; } loop->swap[i] = !(PyArray_ISNOTSWAPPED(mps[i])); if (loop->steps[i]) @@ -2993,6 +3009,7 @@ static int PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, int usertype, PyUFuncGenericFunction function, + int *arg_types, void *data) { PyArray_Descr *descr; @@ -3002,7 +3019,7 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, descr=PyArray_DescrFromType(usertype); if ((usertype < PyArray_USERDEF) || (descr==NULL)) { PyErr_SetString(PyExc_TypeError, - "unknown type"); + "unknown user-defined type"); return -1; } Py_DECREF(descr); @@ -3014,21 +3031,40 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, if (key == NULL) return -1; cobj = PyCObject_FromVoidPtr((void *)function, NULL); if (cobj == NULL) {Py_DECREF(key); return -1;} - if (data == NULL) { + if (data == NULL && arg_types == NULL) { ret = PyDict_SetItem(ufunc->userloops, key, cobj); Py_DECREF(cobj); Py_DECREF(key); return ret; } else { - PyObject *cobj2, *tmp; - cobj2 = PyCObject_FromVoidPtr(data, NULL); - if (cobj2 == NULL) { - Py_DECREF(cobj); - Py_DECREF(key); - return -1; + PyObject *cobj2, *cobj3, *tmp; + if (arg_types == NULL) { + cobj2 = Py_None; + Py_INCREF(cobj2); + } + else { + cobj2 = PyCObject_FromVoidPtr((void *)arg_types, NULL); + if (cobj2 == NULL) { + Py_DECREF(cobj); + Py_DECREF(key); + return -1; + } + } + if (data == NULL) { + cobj3 = Py_None; + Py_INCREF(cobj3); + } + else { + cobj3 = PyCObject_FromVoidPtr(data, NULL); + if (cobj3 == NULL) { + Py_DECREF(cobj2); + Py_DECREF(cobj); + Py_DECREF(key); + return -1; + } } - tmp=Py_BuildValue("NN", cobj, cobj2); + tmp=Py_BuildValue("NNN", cobj, cobj2, cobj3); ret = PyDict_SetItem(ufunc->userloops, key, tmp); Py_DECREF(tmp); Py_DECREF(key); |