diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/_struct_ufunc_tests.c.src | 61 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 1 |
3 files changed, 64 insertions, 11 deletions
diff --git a/numpy/core/src/umath/_struct_ufunc_tests.c.src b/numpy/core/src/umath/_struct_ufunc_tests.c.src index 5c6e235e0..3eaac73e1 100644 --- a/numpy/core/src/umath/_struct_ufunc_tests.c.src +++ b/numpy/core/src/umath/_struct_ufunc_tests.c.src @@ -17,12 +17,6 @@ * docs.python.org . */ -static PyMethodDef StructUfuncTestMethods[] = { - {NULL, NULL, 0, NULL} -}; - -/* The loop definition must precede the PyMODINIT_FUNC. */ - static void add_uint64_triplet(char **args, npy_intp *dimensions, npy_intp* steps, void* data) { @@ -53,6 +47,59 @@ static void add_uint64_triplet(char **args, npy_intp *dimensions, } } +static PyObject* +register_fail(PyObject* NPY_UNUSED(self), PyObject* NPY_UNUSED(args)) +{ + PyObject *add_triplet; + PyObject *dtype_dict; + PyArray_Descr *dtype; + PyArray_Descr *dtypes[3]; + int retval; + + add_triplet = PyUFunc_FromFuncAndData(NULL, NULL, NULL, 0, 2, 1, + PyUFunc_None, "add_triplet", + "add_triplet_docstring", 0); + + dtype_dict = Py_BuildValue("[(s, s), (s, s), (s, s)]", + "f0", "u8", "f1", "u8", "f2", "u8"); + PyArray_DescrConverter(dtype_dict, &dtype); + Py_DECREF(dtype_dict); + + dtypes[0] = dtype; + dtypes[1] = dtype; + dtypes[2] = dtype; + + retval = PyUFunc_RegisterLoopForDescr((PyUFuncObject *)add_triplet, + dtype, + &add_uint64_triplet, + dtypes, + NULL); + + if (retval < 0) { + Py_DECREF(add_triplet); + Py_DECREF(dtype); + return NULL; + } + retval = PyUFunc_RegisterLoopForDescr((PyUFuncObject *)add_triplet, + dtype, + &add_uint64_triplet, + dtypes, + NULL); + Py_DECREF(add_triplet); + Py_DECREF(dtype); + if (retval < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyMethodDef StructUfuncTestMethods[] = { + {"register_fail", + register_fail, + METH_NOARGS, NULL}, + {NULL, NULL, 0, NULL} +}; + #if defined(NPY_PY3K) static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, @@ -100,7 +147,7 @@ PyMODINIT_FUNC init_struct_ufunc_tests(void) "add_triplet_docstring", 0); dtype_dict = Py_BuildValue("[(s, s), (s, s), (s, s)]", - "f0", "u8", "f1", "u8", "f2", "u8"); + "f0", "u8", "f1", "u8", "f2", "u8"); PyArray_DescrConverter(dtype_dict, &dtype); Py_DECREF(dtype_dict); diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index ab986caa1..f198a19bd 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -5110,11 +5110,14 @@ _loop1d_list_free(void *ptr) * instead of dtype type num values. This allows a 1-d loop to be registered * for a structured array dtype or a custom dtype. The ufunc is called * whenever any of it's input arguments match the user_dtype argument. - * ufunc - ufunc object created from call to PyUFunc_FromFuncAndData + * + * ufunc - ufunc object created from call to PyUFunc_FromFuncAndData * user_dtype - dtype that ufunc will be registered with - * function - 1-d loop function pointer + * function - 1-d loop function pointer * arg_dtypes - array of dtype objects describing the ufunc operands - * data - arbitrary data pointer passed in to loop function + * data - arbitrary data pointer passed in to loop function + * + * returns 0 on success, -1 for failure */ /*UFUNC_API*/ NPY_NO_EXPORT int @@ -5178,7 +5181,7 @@ PyUFunc_RegisterLoopForDescr(PyUFuncObject *ufunc, } current = current->next; } - if (cmp == 0 && current->arg_dtypes == NULL) { + if (cmp == 0 && current != NULL && current->arg_dtypes == NULL) { current->arg_dtypes = PyArray_malloc(ufunc->nargs * sizeof(PyArray_Descr*)); if (arg_dtypes != NULL) { @@ -5196,6 +5199,8 @@ PyUFunc_RegisterLoopForDescr(PyUFuncObject *ufunc, current->nargs = ufunc->nargs; } else { + PyErr_SetString(PyExc_RuntimeError, + "loop already registered"); result = -1; } } diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index fa62767f6..9e5e3fb77 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -1580,6 +1580,7 @@ class TestUfunc(object): result = struct_ufunc.add_triplet(a, b) assert_equal(result, np.array([(2, 4, 6)], dtype='u8,u8,u8')) + assert_raises(RuntimeError, struct_ufunc.register_fail) def test_custom_ufunc(self): a = np.array( |