diff options
author | Nathan Goldbaum <nathan.goldbaum@gmail.com> | 2023-05-02 11:13:25 -0600 |
---|---|---|
committer | Nathan Goldbaum <nathan.goldbaum@gmail.com> | 2023-05-02 11:13:25 -0600 |
commit | 4a630dca297c72581599f05801b162ffc0046011 (patch) | |
tree | 1052b1288c923ceb6193e0951ec410eb9aadad92 /numpy/core | |
parent | 525c35b5083da99dec8a5756d1b86099b2cf0c6b (diff) | |
download | numpy-4a630dca297c72581599f05801b162ffc0046011.tar.gz |
MAINT: refactor PyArray_Repeat to avoid PyArray_INCREF
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 508b830f0..1fe051789 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -798,6 +798,10 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) aop = (PyArrayObject *)ap; n = PyArray_DIM(aop, axis); + NPY_cast_info cast_info; + NPY_ARRAYMETHOD_FLAGS flags; + NPY_cast_info_init(&cast_info); + int needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(aop)); if (!broadcast && PyArray_SIZE(repeats) != n) { PyErr_Format(PyExc_ValueError, @@ -844,11 +848,31 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) for (i = 0; i < axis; i++) { n_outer *= PyArray_DIMS(aop)[i]; } + + if (needs_refcounting) { + if (PyArray_GetDTypeTransferFunction( + 1, chunk, chunk, PyArray_DESCR(aop), PyArray_DESCR(aop), 0, + &cast_info, &flags) < 0) { + goto fail; + } + } + for (i = 0; i < n_outer; i++) { for (j = 0; j < n; j++) { npy_intp tmp = broadcast ? counts[0] : counts[j]; for (k = 0; k < tmp; k++) { - memcpy(new_data, old_data, chunk); + if (!needs_refcounting) { + memcpy(new_data, old_data, chunk); + } + else { + char *data[2] = {old_data, new_data}; + npy_intp strides[2] = {chunk, chunk}; + npy_intp one = 1; + if (cast_info.func(&cast_info.context, data, &one, strides, + cast_info.auxdata) < 0) { + goto fail; + } + } new_data += chunk; } old_data += chunk; @@ -856,14 +880,15 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) } Py_DECREF(repeats); - PyArray_INCREF(ret); Py_XDECREF(aop); + NPY_cast_info_xfree(&cast_info); return (PyObject *)ret; fail: Py_DECREF(repeats); Py_XDECREF(aop); Py_XDECREF(ret); + NPY_cast_info_xfree(&cast_info); return NULL; } |