summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Goldbaum <nathan.goldbaum@gmail.com>2023-05-11 14:53:49 -0600
committerNathan Goldbaum <nathan.goldbaum@gmail.com>2023-05-12 11:11:53 -0600
commit4116b4be7a1b4b464e5a91dbe465cb469369e1e0 (patch)
tree32b3ca634cd7c41fe1c9ff0f1199e99aef14eabc
parentf45f692abac62265b760c0a249e8b417707c47d1 (diff)
downloadnumpy-4116b4be7a1b4b464e5a91dbe465cb469369e1e0.tar.gz
MAINT: do not use copyswapn in array sorting internals
-rw-r--r--numpy/core/src/multiarray/item_selection.c106
1 files changed, 91 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 676a2a6b4..f42ae7c2d 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1122,10 +1122,10 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
npy_intp astride = PyArray_STRIDE(op, axis);
int swap = PyArray_ISBYTESWAPPED(op);
- int needcopy = !IsAligned(op) || swap || astride != elsize;
+ int is_aligned = IsAligned(op);
+ int needcopy = !is_aligned || swap || astride != elsize;
int needs_api = PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_PYAPI);
- PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn;
char *buffer = NULL;
PyArrayIterObject *it;
@@ -1133,6 +1133,12 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
int ret = 0;
+ PyArray_Descr *descr = PyArray_DESCR(op);
+ PyArray_Descr *odescr = NULL;
+
+ NPY_cast_info to_cast_info = {.func = NULL};
+ NPY_cast_info from_cast_info = {.func = NULL};
+
NPY_BEGIN_THREADS_DEF;
/* Check if there is any sorting to do */
@@ -1157,18 +1163,49 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
ret = -1;
goto fail;
}
- if (PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_INIT)) {
+ if (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) {
memset(buffer, 0, N * elsize);
}
+
+ if (swap) {
+ odescr = PyArray_DescrNewByteorder(descr, NPY_SWAP);
+ }
+ else {
+ odescr = descr;
+ Py_INCREF(odescr);
+ }
+
+ NPY_ARRAYMETHOD_FLAGS to_transfer_flags;
+
+ if (PyArray_GetDTypeTransferFunction(
+ is_aligned, astride, elsize, descr, odescr, 0, &to_cast_info,
+ &to_transfer_flags) != NPY_SUCCEED) {
+ goto fail;
+ }
+
+ NPY_ARRAYMETHOD_FLAGS from_transfer_flags;
+
+ if (PyArray_GetDTypeTransferFunction(
+ is_aligned, elsize, astride, odescr, descr, 0, &from_cast_info,
+ &from_transfer_flags) != NPY_SUCCEED) {
+ goto fail;
+ }
}
- NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op));
+ NPY_BEGIN_THREADS_DESCR(descr);
while (size--) {
char *bufptr = it->dataptr;
if (needcopy) {
- copyswapn(buffer, elsize, it->dataptr, astride, N, swap, op);
+ char *args[2] = {it->dataptr, buffer};
+ npy_intp strides[2] = {astride, elsize};
+
+ if (NPY_UNLIKELY(to_cast_info.func(
+ &to_cast_info.context, args, &N, strides,
+ to_cast_info.auxdata) < 0)) {
+ goto fail;
+ }
bufptr = buffer;
}
/*
@@ -1204,18 +1241,26 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
}
if (needcopy) {
- copyswapn(it->dataptr, astride, buffer, elsize, N, swap, op);
+ char *args[2] = {buffer, it->dataptr};
+ npy_intp strides[2] = {elsize, astride};
+
+ if (NPY_UNLIKELY(from_cast_info.func(
+ &from_cast_info.context, args, &N, strides,
+ from_cast_info.auxdata) < 0)) {
+ goto fail;
+ }
}
PyArray_ITER_NEXT(it);
}
fail:
- NPY_END_THREADS_DESCR(PyArray_DESCR(op));
+ NPY_END_THREADS_DESCR(descr);
/* cleanup internal buffer */
if (needcopy) {
- PyArray_ClearBuffer(PyArray_DESCR(op), buffer, elsize, N, 1);
+ PyArray_ClearBuffer(odescr, buffer, elsize, N, 1);
PyDataMem_UserFREE(buffer, N * elsize, mem_handler);
+ Py_DECREF(odescr);
}
if (ret < 0 && !PyErr_Occurred()) {
/* Out of memory during sorting or buffer creation */
@@ -1223,6 +1268,8 @@ fail:
}
Py_DECREF(it);
Py_DECREF(mem_handler);
+ NPY_cast_info_xfree(&to_cast_info);
+ NPY_cast_info_xfree(&from_cast_info);
return ret;
}
@@ -1236,11 +1283,11 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
npy_intp astride = PyArray_STRIDE(op, axis);
int swap = PyArray_ISBYTESWAPPED(op);
- int needcopy = !IsAligned(op) || swap || astride != elsize;
+ int is_aligned = IsAligned(op);
+ int needcopy = !is_aligned || swap || astride != elsize;
int needs_api = PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_PYAPI);
int needidxbuffer;
- PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn;
char *valbuffer = NULL;
npy_intp *idxbuffer = NULL;
@@ -1252,6 +1299,12 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
int ret = 0;
+ PyArray_Descr *descr = PyArray_DESCR(op);
+ PyArray_Descr *odescr = NULL;
+
+ NPY_ARRAYMETHOD_FLAGS transfer_flags;
+ NPY_cast_info cast_info = {.func = NULL};
+
NPY_BEGIN_THREADS_DEF;
PyObject *mem_handler = PyDataMem_GetHandler();
@@ -1290,9 +1343,23 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
ret = -1;
goto fail;
}
- if (PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_INIT)) {
+ if (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) {
memset(valbuffer, 0, N * elsize);
}
+
+ if (swap) {
+ odescr = PyArray_DescrNewByteorder(descr, NPY_SWAP);
+ }
+ else {
+ odescr = descr;
+ Py_INCREF(odescr);
+ }
+
+ if (PyArray_GetDTypeTransferFunction(
+ is_aligned, astride, elsize, descr, odescr, 0, &cast_info,
+ &transfer_flags) != NPY_SUCCEED) {
+ goto fail;
+ }
}
if (needidxbuffer) {
@@ -1304,7 +1371,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
}
}
- NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op));
+ NPY_BEGIN_THREADS_DESCR(descr);
while (size--) {
char *valptr = it->dataptr;
@@ -1312,7 +1379,14 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
npy_intp *iptr, i;
if (needcopy) {
- copyswapn(valbuffer, elsize, it->dataptr, astride, N, swap, op);
+ char *args[2] = {it->dataptr, valbuffer};
+ npy_intp strides[2] = {astride, elsize};
+
+ if (NPY_UNLIKELY(cast_info.func(
+ &cast_info.context, args, &N, strides,
+ cast_info.auxdata) < 0)) {
+ goto fail;
+ }
valptr = valbuffer;
}
@@ -1366,11 +1440,12 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
}
fail:
- NPY_END_THREADS_DESCR(PyArray_DESCR(op));
+ NPY_END_THREADS_DESCR(descr);
/* cleanup internal buffers */
if (needcopy) {
- PyArray_ClearBuffer(PyArray_DESCR(op), valbuffer, elsize, N, 1);
+ PyArray_ClearBuffer(odescr, valbuffer, elsize, N, 1);
PyDataMem_UserFREE(valbuffer, N * elsize, mem_handler);
+ Py_DECREF(odescr);
}
PyDataMem_UserFREE(idxbuffer, N * sizeof(npy_intp), mem_handler);
if (ret < 0) {
@@ -1384,6 +1459,7 @@ fail:
Py_XDECREF(it);
Py_XDECREF(rit);
Py_DECREF(mem_handler);
+ NPY_cast_info_xfree(&cast_info);
return (PyObject *)rop;
}