summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorNathan Goldbaum <nathan.goldbaum@gmail.com>2023-05-02 11:13:25 -0600
committerNathan Goldbaum <nathan.goldbaum@gmail.com>2023-05-02 11:13:25 -0600
commit4a630dca297c72581599f05801b162ffc0046011 (patch)
tree1052b1288c923ceb6193e0951ec410eb9aadad92 /numpy/core
parent525c35b5083da99dec8a5756d1b86099b2cf0c6b (diff)
downloadnumpy-4a630dca297c72581599f05801b162ffc0046011.tar.gz
MAINT: refactor PyArray_Repeat to avoid PyArray_INCREF
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/item_selection.c29
1 files changed, 27 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 508b830f0..1fe051789 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -798,6 +798,10 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
aop = (PyArrayObject *)ap;
n = PyArray_DIM(aop, axis);
+ NPY_cast_info cast_info;
+ NPY_ARRAYMETHOD_FLAGS flags;
+ NPY_cast_info_init(&cast_info);
+ int needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(aop));
if (!broadcast && PyArray_SIZE(repeats) != n) {
PyErr_Format(PyExc_ValueError,
@@ -844,11 +848,31 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
for (i = 0; i < axis; i++) {
n_outer *= PyArray_DIMS(aop)[i];
}
+
+ if (needs_refcounting) {
+ if (PyArray_GetDTypeTransferFunction(
+ 1, chunk, chunk, PyArray_DESCR(aop), PyArray_DESCR(aop), 0,
+ &cast_info, &flags) < 0) {
+ goto fail;
+ }
+ }
+
for (i = 0; i < n_outer; i++) {
for (j = 0; j < n; j++) {
npy_intp tmp = broadcast ? counts[0] : counts[j];
for (k = 0; k < tmp; k++) {
- memcpy(new_data, old_data, chunk);
+ if (!needs_refcounting) {
+ memcpy(new_data, old_data, chunk);
+ }
+ else {
+ char *data[2] = {old_data, new_data};
+ npy_intp strides[2] = {chunk, chunk};
+ npy_intp one = 1;
+ if (cast_info.func(&cast_info.context, data, &one, strides,
+ cast_info.auxdata) < 0) {
+ goto fail;
+ }
+ }
new_data += chunk;
}
old_data += chunk;
@@ -856,14 +880,15 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
}
Py_DECREF(repeats);
- PyArray_INCREF(ret);
Py_XDECREF(aop);
+ NPY_cast_info_xfree(&cast_info);
return (PyObject *)ret;
fail:
Py_DECREF(repeats);
Py_XDECREF(aop);
Py_XDECREF(ret);
+ NPY_cast_info_xfree(&cast_info);
return NULL;
}