summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-01-03 17:48:20 +0100
committerSebastian Berg <sebastianb@nvidia.com>2023-02-19 19:52:04 +0100
commit283f36b05e625928ca16c86633fb30e26342eb97 (patch)
treee3f787dba696a4c446392b5db11cbb2d4b739e14 /numpy
parent7b15e26c8b246095cdd8800d0e065be74ee85447 (diff)
downloadnumpy-283f36b05e625928ca16c86633fb30e26342eb97.tar.gz
WIP: Further fixups and full implementation for structured dtypes
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/dtype_transfer.c186
-rw-r--r--numpy/core/src/multiarray/dtype_traversal.c255
-rw-r--r--numpy/core/src/multiarray/dtype_traversal.h21
-rw-r--r--numpy/core/src/multiarray/refcount.c14
4 files changed, 299 insertions, 177 deletions
diff --git a/numpy/core/src/multiarray/dtype_transfer.c b/numpy/core/src/multiarray/dtype_transfer.c
index 40e0af733..518269f39 100644
--- a/numpy/core/src/multiarray/dtype_transfer.c
+++ b/numpy/core/src/multiarray/dtype_transfer.c
@@ -32,6 +32,7 @@
#include "shape.h"
#include "dtype_transfer.h"
+#include "dtype_traversal.h"
#include "alloc.h"
#include "dtypemeta.h"
#include "array_method.h"
@@ -147,7 +148,7 @@ typedef struct {
NpyAuxData base;
PyArray_GetItemFunc *getitem;
PyArrayObject_fields arr_fields;
- NPY_cast_info decref_src;
+ NPY_traverse_info decref_src;
} _any_to_object_auxdata;
@@ -157,7 +158,7 @@ _any_to_object_auxdata_free(NpyAuxData *auxdata)
_any_to_object_auxdata *data = (_any_to_object_auxdata *)auxdata;
Py_DECREF(data->arr_fields.descr);
- NPY_cast_info_xfree(&data->decref_src);
+ NPY_traverse_info_xfree(&data->decref_src);
PyMem_Free(data);
}
@@ -175,7 +176,7 @@ _any_to_object_auxdata_clone(NpyAuxData *auxdata)
Py_INCREF(res->arr_fields.descr);
if (data->decref_src.func != NULL) {
- if (NPY_cast_info_copy(&res->decref_src, &data->decref_src) < 0) {
+ if (NPY_traverse_info_copy(&res->decref_src, &data->decref_src) < 0) {
NPY_AUXDATA_FREE((NpyAuxData *)res);
return NULL;
}
@@ -216,8 +217,8 @@ _strided_to_strided_any_to_object(
}
if (data->decref_src.func != NULL) {
/* If necessary, clear the input buffer (`move_references`) */
- if (data->decref_src.func(&data->decref_src.context,
- &orig_src, &N, &src_stride, data->decref_src.auxdata) < 0) {
+ if (data->decref_src.func(NULL, data->decref_src.descr,
+ orig_src, N, src_stride, data->decref_src.auxdata) < 0) {
return -1;
}
}
@@ -253,11 +254,11 @@ any_to_object_get_loop(
data->arr_fields.nd = 0;
data->getitem = context->descriptors[0]->f->getitem;
- NPY_cast_info_init(&data->decref_src);
+ NPY_traverse_info_init(&data->decref_src);
if (move_references && PyDataType_REFCHK(context->descriptors[0])) {
NPY_ARRAYMETHOD_FLAGS clear_flags;
- if (get_clear_function(
+ if (PyArray_GetClearFunction(
aligned, strides[0], context->descriptors[0],
&data->decref_src, &clear_flags) < 0) {
NPY_AUXDATA_FREE(*out_transferdata);
@@ -1381,7 +1382,7 @@ typedef struct {
npy_intp N;
NPY_cast_info wrapped;
/* If finish->func is non-NULL the source needs a decref */
- NPY_cast_info decref_src;
+ NPY_traverse_info decref_src;
} _one_to_n_data;
/* transfer data free function */
@@ -1389,7 +1390,7 @@ static void _one_to_n_data_free(NpyAuxData *data)
{
_one_to_n_data *d = (_one_to_n_data *)data;
NPY_cast_info_xfree(&d->wrapped);
- NPY_cast_info_xfree(&d->decref_src);
+ NPY_traverse_info_xfree(&d->decref_src);
PyMem_Free(data);
}
@@ -1408,7 +1409,7 @@ static NpyAuxData *_one_to_n_data_clone(NpyAuxData *data)
newdata->base.clone = &_one_to_n_data_clone;
newdata->N = d->N;
/* Initialize in case of error, or if it is unused */
- NPY_cast_info_init(&newdata->decref_src);
+ NPY_traverse_info_init(&newdata->decref_src);
if (NPY_cast_info_copy(&newdata->wrapped, &d->wrapped) < 0) {
_one_to_n_data_free((NpyAuxData *)newdata);
@@ -1418,7 +1419,7 @@ static NpyAuxData *_one_to_n_data_clone(NpyAuxData *data)
return (NpyAuxData *)newdata;
}
- if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
+ if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
_one_to_n_data_free((NpyAuxData *)newdata);
return NULL;
}
@@ -1478,8 +1479,8 @@ _strided_to_strided_one_to_n_with_finish(
return -1;
}
- if (d->decref_src.func(&d->decref_src.context,
- &src, &one_item, &zero_stride, d->decref_src.auxdata) < 0) {
+ if (d->decref_src.func(NULL, d->decref_src.descr,
+ src, one_item, zero_stride, d->decref_src.auxdata) < 0) {
return -1;
}
@@ -1510,7 +1511,7 @@ get_one_to_n_transfer_function(int aligned,
data->base.free = &_one_to_n_data_free;
data->base.clone = &_one_to_n_data_clone;
data->N = N;
- NPY_cast_info_init(&data->decref_src); /* In case of error */
+ NPY_traverse_info_init(&data->decref_src); /* In case of error */
/*
* move_references is set to 0, handled in the wrapping transfer fn,
@@ -1531,7 +1532,7 @@ get_one_to_n_transfer_function(int aligned,
/* If the src object will need a DECREF, set src_dtype */
if (move_references && PyDataType_REFCHK(src_dtype)) {
NPY_ARRAYMETHOD_FLAGS clear_flags;
- if (get_clear_function(
+ if (PyArray_GetClearFunction(
aligned, src_stride, src_dtype,
&data->decref_src, &clear_flags) < 0) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
@@ -1728,8 +1729,8 @@ typedef struct {
typedef struct {
NpyAuxData base;
NPY_cast_info wrapped;
- NPY_cast_info decref_src;
- NPY_cast_info decref_dst; /* The use-case should probably be deprecated */
+ NPY_traverse_info decref_src;
+ NPY_traverse_info decref_dst; /* The use-case should probably be deprecated */
npy_intp src_N, dst_N;
/* This gets a run-length encoded representation of the transfer */
npy_intp run_count;
@@ -1742,8 +1743,8 @@ static void _subarray_broadcast_data_free(NpyAuxData *data)
{
_subarray_broadcast_data *d = (_subarray_broadcast_data *)data;
NPY_cast_info_xfree(&d->wrapped);
- NPY_cast_info_xfree(&d->decref_src);
- NPY_cast_info_xfree(&d->decref_dst);
+ NPY_traverse_info_xfree(&d->decref_src);
+ NPY_traverse_info_xfree(&d->decref_dst);
PyMem_Free(data);
}
@@ -1767,21 +1768,21 @@ static NpyAuxData *_subarray_broadcast_data_clone(NpyAuxData *data)
newdata->run_count = d->run_count;
memcpy(newdata->offsetruns, d->offsetruns, offsetruns_size);
- NPY_cast_info_init(&newdata->decref_src);
- NPY_cast_info_init(&newdata->decref_dst);
+ NPY_traverse_info_init(&newdata->decref_src);
+ NPY_traverse_info_init(&newdata->decref_dst);
if (NPY_cast_info_copy(&newdata->wrapped, &d->wrapped) < 0) {
_subarray_broadcast_data_free((NpyAuxData *)newdata);
return NULL;
}
if (d->decref_src.func != NULL) {
- if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
+ if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
_subarray_broadcast_data_free((NpyAuxData *) newdata);
return NULL;
}
}
if (d->decref_dst.func != NULL) {
- if (NPY_cast_info_copy(&newdata->decref_dst, &d->decref_dst) < 0) {
+ if (NPY_traverse_info_copy(&newdata->decref_dst, &d->decref_dst) < 0) {
_subarray_broadcast_data_free((NpyAuxData *) newdata);
return NULL;
}
@@ -1870,8 +1871,8 @@ _strided_to_strided_subarray_broadcast_withrefs(
}
else {
if (d->decref_dst.func != NULL) {
- if (d->decref_dst.func(&d->decref_dst.context,
- &dst_ptr, &count, &dst_subitemsize,
+ if (d->decref_dst.func(NULL, d->decref_dst.descr,
+ dst_ptr, count, dst_subitemsize,
d->decref_dst.auxdata) < 0) {
return -1;
}
@@ -1882,8 +1883,8 @@ _strided_to_strided_subarray_broadcast_withrefs(
}
if (d->decref_src.func != NULL) {
- if (d->decref_src.func(&d->decref_src.context,
- &src, &d->src_N, &src_subitemsize,
+ if (d->decref_src.func(NULL, d->decref_src.descr,
+ src, d->src_N, src_subitemsize,
d->decref_src.auxdata) < 0) {
return -1;
}
@@ -1926,8 +1927,8 @@ get_subarray_broadcast_transfer_function(int aligned,
data->src_N = src_size;
data->dst_N = dst_size;
- NPY_cast_info_init(&data->decref_src);
- NPY_cast_info_init(&data->decref_dst);
+ NPY_traverse_info_init(&data->decref_src);
+ NPY_traverse_info_init(&data->decref_dst);
/*
* move_references is set to 0, handled in the wrapping transfer fn,
@@ -1946,12 +1947,9 @@ get_subarray_broadcast_transfer_function(int aligned,
/* If the src object will need a DECREF */
if (move_references && PyDataType_REFCHK(src_dtype)) {
- if (PyArray_GetDTypeTransferFunction(aligned,
- src_dtype->elsize, 0,
- src_dtype, NULL,
- 1,
- &data->decref_src,
- out_flags) != NPY_SUCCEED) {
+ if (PyArray_GetClearFunction(aligned,
+ src_dtype->elsize, src_dtype,
+ &data->decref_src, out_flags) != NPY_SUCCEED) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
return NPY_FAIL;
}
@@ -1959,12 +1957,9 @@ get_subarray_broadcast_transfer_function(int aligned,
/* If the dst object needs a DECREF to set it to NULL */
if (PyDataType_REFCHK(dst_dtype)) {
- if (PyArray_GetDTypeTransferFunction(aligned,
- dst_dtype->elsize, 0,
- dst_dtype, NULL,
- 1,
- &data->decref_dst,
- out_flags) != NPY_SUCCEED) {
+ if (PyArray_GetClearFunction(aligned,
+ dst_dtype->elsize, dst_dtype,
+ &data->decref_dst, out_flags) != NPY_SUCCEED) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
return NPY_FAIL;
}
@@ -2169,6 +2164,7 @@ typedef struct {
typedef struct {
NpyAuxData base;
npy_intp field_count;
+ NPY_traverse_info decref_func;
_single_field_transfer fields[];
} _field_transfer_data;
@@ -2333,7 +2329,7 @@ get_fields_transfer_function(int NPY_UNUSED(aligned),
*/
if (move_references && PyDataType_REFCHK(src_dtype)) {
NPY_ARRAYMETHOD_FLAGS clear_flags;
- if (get_clear_function(
+ if (PyArray_GetClearFunction(
0, src_stride, src_dtype,
&data->fields[field_count].info,
&clear_flags) < 0) {
@@ -2468,7 +2464,7 @@ typedef struct {
/* The transfer function being wrapped (could likely be stored directly) */
NPY_cast_info wrapped;
/* The src decref function if necessary */
- NPY_cast_info decref_src;
+ NPY_traverse_info decref_src;
} _masked_wrapper_transfer_data;
/* transfer data free function */
@@ -2477,7 +2473,7 @@ _masked_wrapper_transfer_data_free(NpyAuxData *data)
{
_masked_wrapper_transfer_data *d = (_masked_wrapper_transfer_data *)data;
NPY_cast_info_xfree(&d->wrapped);
- NPY_cast_info_xfree(&d->decref_src);
+ NPY_traverse_info_xfree(&d->decref_src);
PyMem_Free(data);
}
@@ -2500,7 +2496,7 @@ _masked_wrapper_transfer_data_clone(NpyAuxData *data)
return NULL;
}
if (d->decref_src.func != NULL) {
- if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
+ if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) {
NPY_AUXDATA_FREE((NpyAuxData *)newdata);
return NULL;
}
@@ -2527,8 +2523,8 @@ _strided_masked_wrapper_clear_function(
/* Skip masked values, still calling decref for move_references */
mask = (npy_bool*)npy_memchr((char *)mask, 0, mask_stride, N,
&subloopsize, 1);
- if (d->decref_src.func(&d->decref_src.context,
- &src, &subloopsize, &src_stride, d->decref_src.auxdata) < 0) {
+ if (d->decref_src.func(NULL, d->decref_src.descr,
+ src, subloopsize, src_stride, d->decref_src.auxdata) < 0) {
return -1;
}
dst += subloopsize * dst_stride;
@@ -2608,49 +2604,6 @@ _cast_no_op(
/*
- * Get a strided loop used for clearing. This looks up the `clearimpl`
- * slot on the DType.
- * The function will error when `clearimpl` is not defined. Note that
- * old-style user-dtypes use the "void" version (m)
- */
-static int
-get_clear_function(
- int aligned, npy_intp stride,
- PyArray_Descr *dtype, NPY_cast_info *cast_info,
- NPY_ARRAYMETHOD_FLAGS *flags)
-{
- NPY_cast_info_init(cast_info);
-
- PyArrayMethodObject *clearimpl = NPY_DT_SLOTS(NPY_DTYPE(dtype))->clearimpl;
- if (clearimpl == NULL) {
- PyErr_Format(PyExc_RuntimeError,
- "Internal error, tried to fetch decref function for the "
- "unsupported DType '%S'.", dtype);
- return -1;
- }
-
- /* Make sure all important fields are either set or cleared */
- Py_INCREF(dtype);
- cast_info->descriptors[0] = dtype;
- cast_info->descriptors[1] = NULL;
- Py_INCREF(clearimpl);
- cast_info->context.method = clearimpl;
- cast_info->context.caller = NULL;
-
- if (clearimpl->get_strided_loop(
- &cast_info->context, aligned, 0, &stride,
- &cast_info->func, &cast_info->auxdata, flags) < 0) {
- Py_CLEAR(cast_info->descriptors[0]);
- Py_CLEAR(cast_info->context.method);
- NPY_cast_info_xfree(cast_info);
- return -1;
- }
-
- return 0;
-}
-
-
-/*
* ********************* Generalized Multistep Cast ************************
*
* New general purpose multiple step cast function when resolve descriptors
@@ -2935,8 +2888,7 @@ _clear_cast_info_after_get_loop_failure(NPY_cast_info *cast_info)
/*
- * Helper for PyArray_GetDTypeTransferFunction, which fetches a single
- * transfer function from the each casting implementation (ArrayMethod).
+ * Fetches a transfer function from the casting implementation (ArrayMethod).
* May set the transfer function to NULL when the cast can be achieved using
* a view.
* TODO: Expand the view functionality for general offsets, not just 0:
@@ -2958,14 +2910,16 @@ _clear_cast_info_after_get_loop_failure(NPY_cast_info *cast_info)
*
* Returns -1 on failure, 0 on success
*/
-static int
-define_cast_for_descrs(
+NPY_NO_EXPORT int
+PyArray_GetDTypeTransferFunction(
int aligned,
npy_intp src_stride, npy_intp dst_stride,
PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype,
int move_references,
NPY_cast_info *cast_info, NPY_ARRAYMETHOD_FLAGS *out_flags)
{
+ assert(dst_dtype != NULL); /* Was previously used for decref */
+
/* Storage for all cast info in case multi-step casting is necessary */
_multistep_castdata castdata;
/* Initialize funcs to NULL to simplify cleanup on error. */
@@ -3117,46 +3071,6 @@ define_cast_for_descrs(
}
-NPY_NO_EXPORT int
-PyArray_GetDTypeTransferFunction(int aligned,
- npy_intp src_stride, npy_intp dst_stride,
- PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype,
- int move_references,
- NPY_cast_info *cast_info,
- NPY_ARRAYMETHOD_FLAGS *out_flags)
-{
- assert(src_dtype != NULL);
-
- /*
- * If one of the dtypes is NULL, we give back either a src decref
- * function or a dst setzero function
- *
- * TODO: Eventually, we may wish to support user dtype with references
- * (including and beyond bare `PyObject *` this may require extending
- * the ArrayMethod API and those paths should likely be split out
- * from this function.)
- */
- if (dst_dtype == NULL) {
- assert(move_references);
- int res = get_clear_function(
- aligned, src_dtype->elsize, src_dtype, cast_info, out_flags);
- if (res < 0) {
- return NPY_FAIL;
- }
- return NPY_SUCCEED;
- }
-
- if (define_cast_for_descrs(aligned,
- src_stride, dst_stride,
- src_dtype, dst_dtype, move_references,
- cast_info, out_flags) < 0) {
- return NPY_FAIL;
- }
-
- return NPY_SUCCEED;
-}
-
-
/*
* Internal wrapping of casts that have to be performed in a "single"
* function (i.e. not by the generic multi-step-cast), but rely on it
@@ -3404,7 +3318,7 @@ PyArray_GetMaskedDTypeTransferFunction(int aligned,
/* If the src object will need a DECREF, get a function to handle that */
if (move_references && PyDataType_REFCHK(src_dtype)) {
NPY_ARRAYMETHOD_FLAGS clear_flags;
- if (get_clear_function(
+ if (PyArray_GetClearFunction(
aligned, src_stride, src_dtype,
&data->decref_src, &clear_flags) < 0) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
@@ -3415,7 +3329,7 @@ PyArray_GetMaskedDTypeTransferFunction(int aligned,
&_strided_masked_wrapper_clear_function;
}
else {
- NPY_cast_info_init(&data->decref_src);
+ NPY_traverse_info_init(&data->decref_src);
cast_info->func = (PyArrayMethod_StridedLoop *)
&_strided_masked_wrapper_transfer_function;
}
diff --git a/numpy/core/src/multiarray/dtype_traversal.c b/numpy/core/src/multiarray/dtype_traversal.c
index 52feb55fe..dec7a4b7a 100644
--- a/numpy/core/src/multiarray/dtype_traversal.c
+++ b/numpy/core/src/multiarray/dtype_traversal.c
@@ -14,10 +14,65 @@
#include "dtypemeta.h"
#include "dtype_traversal.h"
+#include "pyerrors.h"
+#include <stdint.h>
+/* Same as in dtype_transfer.c */
+#define NPY_LOWLEVEL_BUFFER_BLOCKSIZE 128
-/****************** Python Object clear ***********************/
+/*
+ * Generic Clear function helpers:
+ */
+
+static int
+get_clear_function(
+ void *traverse_context, PyArray_Descr *dtype, int aligned,
+ npy_intp stride, NPY_traverse_info *clear_info,
+ NPY_ARRAYMETHOD_FLAGS *flags)
+{
+ NPY_traverse_info_init(clear_info);
+
+ get_simple_loop_function *get_clear = NPY_DT_SLOTS(NPY_DTYPE(dtype))->get_clear_loop;
+ if (get_clear == NULL) {
+ PyErr_Format(PyExc_RuntimeError,
+ "Internal error, tried to fetch decref/clear function for the "
+ "unsupported DType '%S'.", dtype);
+ return -1;
+ }
+
+ if (get_clear(traverse_context, dtype, aligned, stride,
+ &clear_info->func, &clear_info->auxdata, flags) < 0) {
+ /* callee should clean up, but make sure outside debug mode */
+ assert(clear_info->func == NULL);
+ clear_info->func = NULL;
+ return -1;
+ }
+ Py_INCREF(dtype);
+ clear_info->descr = dtype;
+
+ return 0;
+}
+/*
+ * Helper to set up a strided loop used for clearing.
+ * The function will error when called on a dtype which does not have
+ * references (and thus the get_clear_loop slot NULL).
+ * Note that old-style user-dtypes use the "void" version.
+ *
+ * NOTE: This function may have a use for a `traverse_context` at some point
+ * but right now, it is always NULL and only exists to allow adding it
+ * in the future without changing the strided-loop signature.
+ */
+NPY_NO_EXPORT int
+PyArray_GetClearFunction(
+ int aligned, npy_intp stride, PyArray_Descr *dtype,
+ NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags)
+{
+ return get_clear_function(NULL, dtype, aligned, stride, clear_info, flags);
+}
+
+
+/****************** Python Object clear ***********************/
static int
clear_object_strided_loop(
@@ -82,7 +137,7 @@ fields_clear_data_free(NpyAuxData *data)
fields_clear_data *d = (fields_clear_data *)data;
for (npy_intp i = 0; i < d->field_count; ++i) {
- NPY_traverse_info_free(&d->fields[i].info);
+ NPY_traverse_info_xfree(&d->fields[i].info);
}
PyMem_Free(d);
}
@@ -125,16 +180,55 @@ fields_clear_data_clone(NpyAuxData *data)
static int
+traverse_fields_function(
+ void *traverse_context, PyArray_Descr *NPY_UNUSED(descr),
+ char *data, npy_intp N, npy_intp stride,
+ NpyAuxData *auxdata)
+{
+ fields_clear_data *d = (fields_clear_data *)auxdata;
+ npy_intp i, field_count = d->field_count;
+
+ /* Do the traversing a block at a time for better memory caching */
+ const npy_intp blocksize = NPY_LOWLEVEL_BUFFER_BLOCKSIZE;
+
+ for (;;) {
+ if (N > blocksize) {
+ for (i = 0; i < field_count; ++i) {
+ single_field_clear_data field = d->fields[i];
+ if (field.info.func(traverse_context,
+ field.info.descr, data + field.src_offset,
+ blocksize, stride, field.info.auxdata) < 0) {
+ return -1;
+ }
+ }
+ N -= NPY_LOWLEVEL_BUFFER_BLOCKSIZE;
+ data += NPY_LOWLEVEL_BUFFER_BLOCKSIZE * stride;
+ }
+ else {
+ for (i = 0; i < field_count; ++i) {
+ single_field_clear_data field = d->fields[i];
+ if (field.info.func(traverse_context,
+ field.info.descr, data + field.src_offset,
+ blocksize, stride, field.info.auxdata) < 0) {
+ return -1;
+ }
+ }
+ return 0;
+ }
+ }
+}
+
+
+static int
get_clear_fields_transfer_function(
- void *NPY_UNUSED(traverse_context), PyArray_Descr *dtype,
- npy_intp stride, simple_loop_function **out_func,
- NpyAuxData **out_transferdata, NPY_ARRAYMETHOD_FLAGS *flags)
+ void *traverse_context, int NPY_UNUSED(aligned),
+ PyArray_Descr *dtype, npy_intp stride, simple_loop_function **out_func,
+ NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags)
{
PyObject *names, *key, *tup, *title;
PyArray_Descr *fld_dtype;
npy_int i, structsize;
Py_ssize_t field_count;
- int src_offset;
names = dtype->names;
field_count = PyTuple_GET_SIZE(dtype->names);
@@ -154,35 +248,131 @@ get_clear_fields_transfer_function(
single_field_clear_data *field = data->fields;
for (i = 0; i < field_count; ++i) {
+ int offset;
+
key = PyTuple_GET_ITEM(names, i);
tup = PyDict_GetItem(dtype->fields, key);
- if (!PyArg_ParseTuple(tup, "Oi|O", &fld_dtype,
- &offset, &title)) {
+ if (!PyArg_ParseTuple(tup, "Oi|O", &fld_dtype, &offset, &title)) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
return NPY_FAIL;
}
if (PyDataType_REFCHK(fld_dtype)) {
NPY_ARRAYMETHOD_FLAGS clear_flags;
if (get_clear_function(
- 0, stride, fld_dtype,
- &field->info, &clear_flags) < 0) {
+ traverse_context, fld_dtype, 0,
+ stride, &field->info, &clear_flags) < 0) {
NPY_AUXDATA_FREE((NpyAuxData *)data);
return NPY_FAIL;
}
*flags = PyArrayMethod_COMBINED_FLAGS(*flags, clear_flags);
- field->src_offset = src_offset;
+ field->src_offset = offset;
data->field_count++;
field++;
}
}
- *out_stransfer = &_strided_to_strided_field_transfer;
- *out_transferdata = (NpyAuxData *)data;
+ *out_func = &traverse_fields_function;
+ *out_auxdata = (NpyAuxData *)data;
return NPY_SUCCEED;
}
+
+typedef struct {
+ NpyAuxData base;
+ npy_intp count;
+ NPY_traverse_info info;
+} subarray_clear_data;
+
+
+/* transfer data free function */
+static void
+subarray_clear_data_free(NpyAuxData *data)
+{
+ subarray_clear_data *d = (subarray_clear_data *)data;
+
+ NPY_traverse_info_xfree(&d->info);
+ PyMem_Free(d);
+}
+
+
+/* transfer data copy function */
+static NpyAuxData *
+subarray_clear_data_clone(NpyAuxData *data)
+{
+ subarray_clear_data *d = (subarray_clear_data *)data;
+
+ /* Allocate the data and populate it */
+ subarray_clear_data *newdata = PyMem_Malloc(sizeof(subarray_clear_data));
+ if (newdata == NULL) {
+ return NULL;
+ }
+ newdata->count = d->count;
+
+ if (NPY_traverse_info_copy(&newdata->info, &d->info) < 0) {
+ PyMem_Free(newdata);
+ return NULL;
+ }
+
+ return (NpyAuxData *)newdata;
+}
+
+
+static int
+traverse_subarray_func(
+ void *traverse_context, PyArray_Descr *NPY_UNUSED(descr),
+ char *data, npy_intp N, npy_intp stride,
+ NpyAuxData *auxdata)
+{
+ subarray_clear_data *subarr_data = (subarray_clear_data *)auxdata;
+
+ simple_loop_function *func = subarr_data->info.func;
+ PyArray_Descr *sub_descr = subarr_data->info.descr;
+ npy_intp sub_N = subarr_data->count;
+ NpyAuxData *sub_auxdata = subarr_data->info.auxdata;
+ npy_intp sub_stride = sub_descr->elsize;
+
+ while (N--) {
+ if (func(traverse_context, sub_descr, data,
+ sub_N, sub_stride, sub_auxdata) < 0) {
+ return -1;
+ }
+ data += stride;
+ }
+ return 0;
+}
+
+
+static int
+get_subarray_clear_func(
+ void *traverse_context, int aligned, PyArray_Descr *dtype,
+ npy_intp size, npy_intp stride, simple_loop_function **out_func,
+ NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags)
+{
+ subarray_clear_data *auxdata = PyMem_Malloc(sizeof(subarray_clear_data));
+ if (auxdata == NULL) {
+ PyErr_NoMemory();
+ return -1;
+ }
+
+ auxdata->count = size;
+ auxdata->base.free = &subarray_clear_data_free;
+ auxdata->base.clone = &subarray_clear_data_clone;
+
+ if (get_clear_function(
+ traverse_context, dtype, aligned,
+ dtype->elsize, &auxdata->info, flags) < 0) {
+ PyMem_Free(auxdata);
+ return -1;
+ }
+ *out_func = &traverse_subarray_func;
+ *out_auxdata = (NpyAuxData *)auxdata;
+
+ return 0;
+}
+
+
static int
clear_no_op(
void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr),
@@ -195,36 +385,37 @@ clear_no_op(
NPY_NO_EXPORT int
npy_get_clear_void_and_legacy_user_dtype_loop(
- void *NPY_UNUSED(traverse_context), PyArray_Descr *dtype,
+ void *traverse_context, int aligned, PyArray_Descr *dtype,
npy_intp stride, simple_loop_function **out_func,
- NpyAuxData **out_transferdata, NPY_ARRAYMETHOD_FLAGS *flags)
+ NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags)
{
- /* If there are no references, it's a nop (path should not be hit?) */
+ /*
+ * If there are no references, it's a nop. This path should not be hit
+ * but structured dtypes are tricky when a dtype which included references
+ * was sliced to not include any.
+ */
if (!PyDataType_REFCHK(dtype)) {
- *out_loop = &clear_no_op;
- *out_transferdata = NULL;
+ *out_func = &clear_no_op;
+ *out_auxdata = NULL;
assert(0);
return 0;
}
if (PyDataType_HASSUBARRAY(dtype)) {
- PyArray_Dims src_shape = {NULL, -1};
- npy_intp src_size;
+ PyArray_Dims shape = {NULL, -1};
+ npy_intp size;
- if (!(PyArray_IntpConverter(dtype->subarray->shape,
- &src_shape))) {
+ if (!(PyArray_IntpConverter(dtype->subarray->shape, &shape))) {
PyErr_SetString(PyExc_ValueError,
"invalid subarray shape");
return -1;
}
- src_size = PyArray_MultiplyList(src_shape.ptr, src_shape.len);
- npy_free_cache_dim_obj(src_shape);
-
- if (get_n_to_n_transfer_function(aligned,
- stride, 0,
- dtype->subarray->base, NULL, 1, src_size,
- out_loop, out_transferdata,
- flags) != NPY_SUCCEED) {
+ size = PyArray_MultiplyList(shape.ptr, shape.len);
+ npy_free_cache_dim_obj(shape);
+
+ if (get_subarray_clear_func(
+ traverse_context, aligned, dtype->subarray->base, size, stride,
+ out_func, out_auxdata, flags) != NPY_SUCCEED) {
return -1;
}
@@ -233,8 +424,8 @@ npy_get_clear_void_and_legacy_user_dtype_loop(
/* If there are fields, need to do each field */
else if (PyDataType_HASFIELDS(dtype)) {
if (get_clear_fields_transfer_function(
- aligned, stride, dtype,
- out_loop, out_transferdata, flags) < 0) {
+ traverse_context, aligned, dtype, stride,
+ out_func, out_auxdata, flags) < 0) {
return -1;
}
return 0;
diff --git a/numpy/core/src/multiarray/dtype_traversal.h b/numpy/core/src/multiarray/dtype_traversal.h
index 19c3a6de1..919c49eb7 100644
--- a/numpy/core/src/multiarray/dtype_traversal.h
+++ b/numpy/core/src/multiarray/dtype_traversal.h
@@ -33,7 +33,8 @@ typedef int (simple_loop_function)(
/* Simplified get_loop function specific to dtype traversal */
typedef int (get_simple_loop_function)(
- void *traverse_context, int aligned, npy_intp fixed_stride,
+ void *traverse_context, PyArray_Descr *descr,
+ int aligned, npy_intp fixed_stride,
simple_loop_function **out_loop, NpyAuxData **out_auxdata,
NPY_ARRAYMETHOD_FLAGS *flags);
@@ -64,8 +65,18 @@ typedef struct {
static inline void
-NPY_traverse_info_free(NPY_traverse_info *traverse_info)
+NPY_traverse_info_init(NPY_traverse_info *cast_info)
{
+ cast_info->func = NULL; /* mark as uninitialized. */
+}
+
+
+static inline void
+NPY_traverse_info_xfree(NPY_traverse_info *traverse_info)
+{
+ if (traverse_info->func == NULL) {
+ return;
+ }
traverse_info->func = NULL;
NPY_AUXDATA_FREE(traverse_info->auxdata);
Py_DECREF(traverse_info->descr);
@@ -91,4 +102,10 @@ NPY_traverse_info_copy(
}
+NPY_NO_EXPORT int
+PyArray_GetClearFunction(
+ int aligned, npy_intp stride, PyArray_Descr *dtype,
+ NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags);
+
+
#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */ \ No newline at end of file
diff --git a/numpy/core/src/multiarray/refcount.c b/numpy/core/src/multiarray/refcount.c
index daa5bb289..bb2fb982f 100644
--- a/numpy/core/src/multiarray/refcount.c
+++ b/numpy/core/src/multiarray/refcount.c
@@ -3,7 +3,7 @@
* section in the numpy reference for C-API.
*/
#include "array_method.h"
-#include "dtype_transfer.h"
+#include "dtype_traversal.h"
#include "lowlevel_strided_loops.h"
#include "pyerrors.h"
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
@@ -43,16 +43,16 @@ PyArray_ClearData(
return 0;
}
- NPY_cast_info cast_info;
+ NPY_traverse_info clear_info;
NPY_ARRAYMETHOD_FLAGS flags;
- if (PyArray_GetDTypeTransferFunction(
- aligned, stride, 0, descr, NULL, 1, &cast_info, &flags) < 0) {
+ if (PyArray_GetClearFunction(
+ aligned, stride, descr, &clear_info, &flags) < 0) {
return -1;
}
- int res = cast_info.func(
- &cast_info.context, &data, &size, &stride, cast_info.auxdata);
- NPY_cast_info_xfree(&cast_info);
+ int res = clear_info.func(
+ NULL, clear_info.descr, data, size, stride, clear_info.auxdata);
+ NPY_traverse_info_xfree(&clear_info);
return res;
}