diff options
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 9351b5fc7..762563eb5 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2203,9 +2203,13 @@ PyArray_Nonzero(PyArrayObject *self) PyArrayObject *ret = NULL; PyObject *ret_tuple; npy_intp ret_dims[2]; - PyArray_NonzeroFunc *nonzero = PyArray_DESCR(self)->f->nonzero; + + PyArray_NonzeroFunc *nonzero; + PyArray_Descr *dtype; + npy_intp nonzero_count; npy_intp added_count = 0; + int needs_api; int is_bool; NpyIter *iter; @@ -2213,6 +2217,10 @@ PyArray_Nonzero(PyArrayObject *self) NpyIter_GetMultiIndexFunc *get_multi_index; char **dataptr; + dtype = PyArray_DESCR(self); + nonzero = dtype->f->nonzero; + needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI); + /* Special case - nonzero(zero_d) is nonzero(atleast_1d(zero_d)) */ if (ndim == 0) { char const* msg; @@ -2283,7 +2291,9 @@ PyArray_Nonzero(PyArrayObject *self) goto finish; } - NPY_BEGIN_THREADS_THRESHOLDED(count); + if (!needs_api) { + NPY_BEGIN_THREADS_THRESHOLDED(count); + } /* avoid function call for bool */ if (is_bool) { @@ -2324,6 +2334,9 @@ PyArray_Nonzero(PyArrayObject *self) } *multi_index++ = j; } + if (needs_api && PyErr_Occurred()) { + break; + } data += stride; } } @@ -2364,6 +2377,8 @@ PyArray_Nonzero(PyArrayObject *self) Py_DECREF(ret); return NULL; } + + needs_api = NpyIter_IterationNeedsAPI(iter); NPY_BEGIN_THREADS_NDITER(iter); @@ -2390,6 +2405,9 @@ PyArray_Nonzero(PyArrayObject *self) get_multi_index(iter, multi_index); multi_index += ndim; } + if (needs_api && PyErr_Occurred()) { + break; + } } while(iternext(iter)); } @@ -2399,6 +2417,11 @@ PyArray_Nonzero(PyArrayObject *self) NpyIter_Deallocate(iter); finish: + if (PyErr_Occurred()) { + Py_DECREF(ret); + return NULL; + } + /* if executed `nonzero()` check for miscount due to side-effect */ if (!is_bool && added_count != nonzero_count) { PyErr_SetString(PyExc_RuntimeError, |