summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/_struct_ufunc_tests.c.src61
-rw-r--r--numpy/core/src/umath/ufunc_object.c13
-rw-r--r--numpy/core/tests/test_ufunc.py1
3 files changed, 64 insertions, 11 deletions
diff --git a/numpy/core/src/umath/_struct_ufunc_tests.c.src b/numpy/core/src/umath/_struct_ufunc_tests.c.src
index 5c6e235e0..3eaac73e1 100644
--- a/numpy/core/src/umath/_struct_ufunc_tests.c.src
+++ b/numpy/core/src/umath/_struct_ufunc_tests.c.src
@@ -17,12 +17,6 @@
* docs.python.org .
*/
-static PyMethodDef StructUfuncTestMethods[] = {
- {NULL, NULL, 0, NULL}
-};
-
-/* The loop definition must precede the PyMODINIT_FUNC. */
-
static void add_uint64_triplet(char **args, npy_intp *dimensions,
npy_intp* steps, void* data)
{
@@ -53,6 +47,59 @@ static void add_uint64_triplet(char **args, npy_intp *dimensions,
}
}
+static PyObject*
+register_fail(PyObject* NPY_UNUSED(self), PyObject* NPY_UNUSED(args))
+{
+ PyObject *add_triplet;
+ PyObject *dtype_dict;
+ PyArray_Descr *dtype;
+ PyArray_Descr *dtypes[3];
+ int retval;
+
+ add_triplet = PyUFunc_FromFuncAndData(NULL, NULL, NULL, 0, 2, 1,
+ PyUFunc_None, "add_triplet",
+ "add_triplet_docstring", 0);
+
+ dtype_dict = Py_BuildValue("[(s, s), (s, s), (s, s)]",
+ "f0", "u8", "f1", "u8", "f2", "u8");
+ PyArray_DescrConverter(dtype_dict, &dtype);
+ Py_DECREF(dtype_dict);
+
+ dtypes[0] = dtype;
+ dtypes[1] = dtype;
+ dtypes[2] = dtype;
+
+ retval = PyUFunc_RegisterLoopForDescr((PyUFuncObject *)add_triplet,
+ dtype,
+ &add_uint64_triplet,
+ dtypes,
+ NULL);
+
+ if (retval < 0) {
+ Py_DECREF(add_triplet);
+ Py_DECREF(dtype);
+ return NULL;
+ }
+ retval = PyUFunc_RegisterLoopForDescr((PyUFuncObject *)add_triplet,
+ dtype,
+ &add_uint64_triplet,
+ dtypes,
+ NULL);
+ Py_DECREF(add_triplet);
+ Py_DECREF(dtype);
+ if (retval < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyMethodDef StructUfuncTestMethods[] = {
+ {"register_fail",
+ register_fail,
+ METH_NOARGS, NULL},
+ {NULL, NULL, 0, NULL}
+};
+
#if defined(NPY_PY3K)
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
@@ -100,7 +147,7 @@ PyMODINIT_FUNC init_struct_ufunc_tests(void)
"add_triplet_docstring", 0);
dtype_dict = Py_BuildValue("[(s, s), (s, s), (s, s)]",
- "f0", "u8", "f1", "u8", "f2", "u8");
+ "f0", "u8", "f1", "u8", "f2", "u8");
PyArray_DescrConverter(dtype_dict, &dtype);
Py_DECREF(dtype_dict);
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index ab986caa1..f198a19bd 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -5110,11 +5110,14 @@ _loop1d_list_free(void *ptr)
* instead of dtype type num values. This allows a 1-d loop to be registered
* for a structured array dtype or a custom dtype. The ufunc is called
* whenever any of it's input arguments match the user_dtype argument.
- * ufunc - ufunc object created from call to PyUFunc_FromFuncAndData
+ *
+ * ufunc - ufunc object created from call to PyUFunc_FromFuncAndData
* user_dtype - dtype that ufunc will be registered with
- * function - 1-d loop function pointer
+ * function - 1-d loop function pointer
* arg_dtypes - array of dtype objects describing the ufunc operands
- * data - arbitrary data pointer passed in to loop function
+ * data - arbitrary data pointer passed in to loop function
+ *
+ * returns 0 on success, -1 for failure
*/
/*UFUNC_API*/
NPY_NO_EXPORT int
@@ -5178,7 +5181,7 @@ PyUFunc_RegisterLoopForDescr(PyUFuncObject *ufunc,
}
current = current->next;
}
- if (cmp == 0 && current->arg_dtypes == NULL) {
+ if (cmp == 0 && current != NULL && current->arg_dtypes == NULL) {
current->arg_dtypes = PyArray_malloc(ufunc->nargs *
sizeof(PyArray_Descr*));
if (arg_dtypes != NULL) {
@@ -5196,6 +5199,8 @@ PyUFunc_RegisterLoopForDescr(PyUFuncObject *ufunc,
current->nargs = ufunc->nargs;
}
else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "loop already registered");
result = -1;
}
}
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index fa62767f6..9e5e3fb77 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1580,6 +1580,7 @@ class TestUfunc(object):
result = struct_ufunc.add_triplet(a, b)
assert_equal(result, np.array([(2, 4, 6)], dtype='u8,u8,u8'))
+ assert_raises(RuntimeError, struct_ufunc.register_fail)
def test_custom_ufunc(self):
a = np.array(