summaryrefslogtreecommitdiff
path: root/numpy/core/src/ufuncobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/ufuncobject.c')
-rw-r--r--numpy/core/src/ufuncobject.c132
1 files changed, 84 insertions, 48 deletions
diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c
index 25d84d8fd..a4e284ccc 100644
--- a/numpy/core/src/ufuncobject.c
+++ b/numpy/core/src/ufuncobject.c
@@ -625,52 +625,66 @@ select_types(PyUFuncObject *self, int *arg_types,
PyUFuncGenericFunction *function, void **data,
PyArray_SCALARKIND *scalars)
{
-
int i=0, j;
char start_type;
-
- if (PyTypeNum_ISUSERDEF((arg_types[0]))) {
- PyObject *key, *obj;
+
+ if (self->userloops) {
+ int userdef=-1;
for (i=0; i<self->nin; i++) {
- if (arg_types[i] != arg_types[0]) {
- PyErr_SetString(PyExc_TypeError,
- "ufuncs on user defined" \
- " types don't support "\
- "coercion");
- return -1;
+ if (PyTypeNum_ISUSERDEF(arg_types[i])) {
+ userdef = arg_types[i];
+ break;
}
}
- for (i=self->nin; i<self->nargs; i++) {
- arg_types[i] = arg_types[0];
- }
-
- obj = NULL;
- if (self->userloops) {
- key = PyInt_FromLong((long) arg_types[0]);
+ if (userdef > 0) {
+ PyObject *key, *obj;
+ int *this_types=NULL;
+
+ obj = NULL;
+ key = PyInt_FromLong((long) userdef);
if (key == NULL) return -1;
obj = PyDict_GetItem(self->userloops, key);
Py_DECREF(key);
+ if (obj == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "user-defined type used in ufunc" \
+ " with no registered loops");
+ return -1;
+ }
+ if PyTuple_Check(obj) {
+ PyObject *item;
+ *function = (PyUFuncGenericFunction) \
+ PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj,
+ 0));
+ item = PyTuple_GET_ITEM(obj, 2);
+ if (PyCObject_Check(item)) {
+ *data = PyCObject_AsVoidPtr(item);
+ }
+ item = PyTuple_GET_ITEM(obj, 1);
+ if (PyCObject_Check(item)) {
+ this_types = PyCObject_AsVoidPtr(item);
+ }
+ }
+ else {
+ *function = (PyUFuncGenericFunction) \
+ PyCObject_AsVoidPtr(obj);
+ *data = NULL;
+ }
+
+ if (this_types == NULL) {
+ for (i=1; i<self->nargs; i++) {
+ arg_types[i] = userdef;
+ }
+ }
+ else {
+ for (i=1; i<self->nargs; i++) {
+ arg_types[i] = this_types[i];
+ }
+ }
+ Py_DECREF(obj);
+ return 0;
}
- if (obj == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "no registered loop for this " \
- "user-defined type");
- return -1;
- }
- if PyTuple_Check(obj) {
- *function = (PyUFuncGenericFunction) \
- PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj, 0));
- *data = PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj, 1));
- }
- else {
- *function = (PyUFuncGenericFunction) \
- PyCObject_AsVoidPtr(obj);
- *data = NULL;
- }
- Py_DECREF(obj);
- return 0;
}
-
start_type = arg_types[0];
/* If the first argument is a scalar we need to place
@@ -1210,13 +1224,15 @@ construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps)
scntcast += descr->elsize;
if (i < self->nin) {
loop->cast[i] = \
- mps[i]->descr->f->cast[arg_types[i]];
+ PyArray_GetCastFunc(mps[i]->descr,
+ arg_types[i]);
}
else {
- loop->cast[i] = descr->f-> \
- cast[mps[i]->descr->type_num];
+ loop->cast[i] = PyArray_GetCastFunc \
+ (descr, mps[i]->descr->type_num);
}
Py_DECREF(descr);
+ if (!loop->cast[i]) return -1;
}
loop->swap[i] = !(PyArray_ISNOTSWAPPED(mps[i]));
if (loop->steps[i])
@@ -2993,6 +3009,7 @@ static int
PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
int usertype,
PyUFuncGenericFunction function,
+ int *arg_types,
void *data)
{
PyArray_Descr *descr;
@@ -3002,7 +3019,7 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
descr=PyArray_DescrFromType(usertype);
if ((usertype < PyArray_USERDEF) || (descr==NULL)) {
PyErr_SetString(PyExc_TypeError,
- "unknown type");
+ "unknown user-defined type");
return -1;
}
Py_DECREF(descr);
@@ -3014,21 +3031,40 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
if (key == NULL) return -1;
cobj = PyCObject_FromVoidPtr((void *)function, NULL);
if (cobj == NULL) {Py_DECREF(key); return -1;}
- if (data == NULL) {
+ if (data == NULL && arg_types == NULL) {
ret = PyDict_SetItem(ufunc->userloops, key, cobj);
Py_DECREF(cobj);
Py_DECREF(key);
return ret;
}
else {
- PyObject *cobj2, *tmp;
- cobj2 = PyCObject_FromVoidPtr(data, NULL);
- if (cobj2 == NULL) {
- Py_DECREF(cobj);
- Py_DECREF(key);
- return -1;
+ PyObject *cobj2, *cobj3, *tmp;
+ if (arg_types == NULL) {
+ cobj2 = Py_None;
+ Py_INCREF(cobj2);
+ }
+ else {
+ cobj2 = PyCObject_FromVoidPtr((void *)arg_types, NULL);
+ if (cobj2 == NULL) {
+ Py_DECREF(cobj);
+ Py_DECREF(key);
+ return -1;
+ }
+ }
+ if (data == NULL) {
+ cobj3 = Py_None;
+ Py_INCREF(cobj3);
+ }
+ else {
+ cobj3 = PyCObject_FromVoidPtr(data, NULL);
+ if (cobj3 == NULL) {
+ Py_DECREF(cobj2);
+ Py_DECREF(cobj);
+ Py_DECREF(key);
+ return -1;
+ }
}
- tmp=Py_BuildValue("NN", cobj, cobj2);
+ tmp=Py_BuildValue("NNN", cobj, cobj2, cobj3);
ret = PyDict_SetItem(ufunc->userloops, key, tmp);
Py_DECREF(tmp);
Py_DECREF(key);