diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-01-03 17:48:20 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-02-19 19:52:04 +0100 |
commit | 283f36b05e625928ca16c86633fb30e26342eb97 (patch) | |
tree | e3f787dba696a4c446392b5db11cbb2d4b739e14 /numpy | |
parent | 7b15e26c8b246095cdd8800d0e065be74ee85447 (diff) | |
download | numpy-283f36b05e625928ca16c86633fb30e26342eb97.tar.gz |
WIP: Further fixups and full implementation for structured dtypes
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/dtype_transfer.c | 186 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtype_traversal.c | 255 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtype_traversal.h | 21 | ||||
-rw-r--r-- | numpy/core/src/multiarray/refcount.c | 14 |
4 files changed, 299 insertions, 177 deletions
diff --git a/numpy/core/src/multiarray/dtype_transfer.c b/numpy/core/src/multiarray/dtype_transfer.c index 40e0af733..518269f39 100644 --- a/numpy/core/src/multiarray/dtype_transfer.c +++ b/numpy/core/src/multiarray/dtype_transfer.c @@ -32,6 +32,7 @@ #include "shape.h" #include "dtype_transfer.h" +#include "dtype_traversal.h" #include "alloc.h" #include "dtypemeta.h" #include "array_method.h" @@ -147,7 +148,7 @@ typedef struct { NpyAuxData base; PyArray_GetItemFunc *getitem; PyArrayObject_fields arr_fields; - NPY_cast_info decref_src; + NPY_traverse_info decref_src; } _any_to_object_auxdata; @@ -157,7 +158,7 @@ _any_to_object_auxdata_free(NpyAuxData *auxdata) _any_to_object_auxdata *data = (_any_to_object_auxdata *)auxdata; Py_DECREF(data->arr_fields.descr); - NPY_cast_info_xfree(&data->decref_src); + NPY_traverse_info_xfree(&data->decref_src); PyMem_Free(data); } @@ -175,7 +176,7 @@ _any_to_object_auxdata_clone(NpyAuxData *auxdata) Py_INCREF(res->arr_fields.descr); if (data->decref_src.func != NULL) { - if (NPY_cast_info_copy(&res->decref_src, &data->decref_src) < 0) { + if (NPY_traverse_info_copy(&res->decref_src, &data->decref_src) < 0) { NPY_AUXDATA_FREE((NpyAuxData *)res); return NULL; } @@ -216,8 +217,8 @@ _strided_to_strided_any_to_object( } if (data->decref_src.func != NULL) { /* If necessary, clear the input buffer (`move_references`) */ - if (data->decref_src.func(&data->decref_src.context, - &orig_src, &N, &src_stride, data->decref_src.auxdata) < 0) { + if (data->decref_src.func(NULL, data->decref_src.descr, + orig_src, N, src_stride, data->decref_src.auxdata) < 0) { return -1; } } @@ -253,11 +254,11 @@ any_to_object_get_loop( data->arr_fields.nd = 0; data->getitem = context->descriptors[0]->f->getitem; - NPY_cast_info_init(&data->decref_src); + NPY_traverse_info_init(&data->decref_src); if (move_references && PyDataType_REFCHK(context->descriptors[0])) { NPY_ARRAYMETHOD_FLAGS clear_flags; - if (get_clear_function( + if (PyArray_GetClearFunction( aligned, strides[0], context->descriptors[0], &data->decref_src, &clear_flags) < 0) { NPY_AUXDATA_FREE(*out_transferdata); @@ -1381,7 +1382,7 @@ typedef struct { npy_intp N; NPY_cast_info wrapped; /* If finish->func is non-NULL the source needs a decref */ - NPY_cast_info decref_src; + NPY_traverse_info decref_src; } _one_to_n_data; /* transfer data free function */ @@ -1389,7 +1390,7 @@ static void _one_to_n_data_free(NpyAuxData *data) { _one_to_n_data *d = (_one_to_n_data *)data; NPY_cast_info_xfree(&d->wrapped); - NPY_cast_info_xfree(&d->decref_src); + NPY_traverse_info_xfree(&d->decref_src); PyMem_Free(data); } @@ -1408,7 +1409,7 @@ static NpyAuxData *_one_to_n_data_clone(NpyAuxData *data) newdata->base.clone = &_one_to_n_data_clone; newdata->N = d->N; /* Initialize in case of error, or if it is unused */ - NPY_cast_info_init(&newdata->decref_src); + NPY_traverse_info_init(&newdata->decref_src); if (NPY_cast_info_copy(&newdata->wrapped, &d->wrapped) < 0) { _one_to_n_data_free((NpyAuxData *)newdata); @@ -1418,7 +1419,7 @@ static NpyAuxData *_one_to_n_data_clone(NpyAuxData *data) return (NpyAuxData *)newdata; } - if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) { + if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) { _one_to_n_data_free((NpyAuxData *)newdata); return NULL; } @@ -1478,8 +1479,8 @@ _strided_to_strided_one_to_n_with_finish( return -1; } - if (d->decref_src.func(&d->decref_src.context, - &src, &one_item, &zero_stride, d->decref_src.auxdata) < 0) { + if (d->decref_src.func(NULL, d->decref_src.descr, + src, one_item, zero_stride, d->decref_src.auxdata) < 0) { return -1; } @@ -1510,7 +1511,7 @@ get_one_to_n_transfer_function(int aligned, data->base.free = &_one_to_n_data_free; data->base.clone = &_one_to_n_data_clone; data->N = N; - NPY_cast_info_init(&data->decref_src); /* In case of error */ + NPY_traverse_info_init(&data->decref_src); /* In case of error */ /* * move_references is set to 0, handled in the wrapping transfer fn, @@ -1531,7 +1532,7 @@ get_one_to_n_transfer_function(int aligned, /* If the src object will need a DECREF, set src_dtype */ if (move_references && PyDataType_REFCHK(src_dtype)) { NPY_ARRAYMETHOD_FLAGS clear_flags; - if (get_clear_function( + if (PyArray_GetClearFunction( aligned, src_stride, src_dtype, &data->decref_src, &clear_flags) < 0) { NPY_AUXDATA_FREE((NpyAuxData *)data); @@ -1728,8 +1729,8 @@ typedef struct { typedef struct { NpyAuxData base; NPY_cast_info wrapped; - NPY_cast_info decref_src; - NPY_cast_info decref_dst; /* The use-case should probably be deprecated */ + NPY_traverse_info decref_src; + NPY_traverse_info decref_dst; /* The use-case should probably be deprecated */ npy_intp src_N, dst_N; /* This gets a run-length encoded representation of the transfer */ npy_intp run_count; @@ -1742,8 +1743,8 @@ static void _subarray_broadcast_data_free(NpyAuxData *data) { _subarray_broadcast_data *d = (_subarray_broadcast_data *)data; NPY_cast_info_xfree(&d->wrapped); - NPY_cast_info_xfree(&d->decref_src); - NPY_cast_info_xfree(&d->decref_dst); + NPY_traverse_info_xfree(&d->decref_src); + NPY_traverse_info_xfree(&d->decref_dst); PyMem_Free(data); } @@ -1767,21 +1768,21 @@ static NpyAuxData *_subarray_broadcast_data_clone(NpyAuxData *data) newdata->run_count = d->run_count; memcpy(newdata->offsetruns, d->offsetruns, offsetruns_size); - NPY_cast_info_init(&newdata->decref_src); - NPY_cast_info_init(&newdata->decref_dst); + NPY_traverse_info_init(&newdata->decref_src); + NPY_traverse_info_init(&newdata->decref_dst); if (NPY_cast_info_copy(&newdata->wrapped, &d->wrapped) < 0) { _subarray_broadcast_data_free((NpyAuxData *)newdata); return NULL; } if (d->decref_src.func != NULL) { - if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) { + if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) { _subarray_broadcast_data_free((NpyAuxData *) newdata); return NULL; } } if (d->decref_dst.func != NULL) { - if (NPY_cast_info_copy(&newdata->decref_dst, &d->decref_dst) < 0) { + if (NPY_traverse_info_copy(&newdata->decref_dst, &d->decref_dst) < 0) { _subarray_broadcast_data_free((NpyAuxData *) newdata); return NULL; } @@ -1870,8 +1871,8 @@ _strided_to_strided_subarray_broadcast_withrefs( } else { if (d->decref_dst.func != NULL) { - if (d->decref_dst.func(&d->decref_dst.context, - &dst_ptr, &count, &dst_subitemsize, + if (d->decref_dst.func(NULL, d->decref_dst.descr, + dst_ptr, count, dst_subitemsize, d->decref_dst.auxdata) < 0) { return -1; } @@ -1882,8 +1883,8 @@ _strided_to_strided_subarray_broadcast_withrefs( } if (d->decref_src.func != NULL) { - if (d->decref_src.func(&d->decref_src.context, - &src, &d->src_N, &src_subitemsize, + if (d->decref_src.func(NULL, d->decref_src.descr, + src, d->src_N, src_subitemsize, d->decref_src.auxdata) < 0) { return -1; } @@ -1926,8 +1927,8 @@ get_subarray_broadcast_transfer_function(int aligned, data->src_N = src_size; data->dst_N = dst_size; - NPY_cast_info_init(&data->decref_src); - NPY_cast_info_init(&data->decref_dst); + NPY_traverse_info_init(&data->decref_src); + NPY_traverse_info_init(&data->decref_dst); /* * move_references is set to 0, handled in the wrapping transfer fn, @@ -1946,12 +1947,9 @@ get_subarray_broadcast_transfer_function(int aligned, /* If the src object will need a DECREF */ if (move_references && PyDataType_REFCHK(src_dtype)) { - if (PyArray_GetDTypeTransferFunction(aligned, - src_dtype->elsize, 0, - src_dtype, NULL, - 1, - &data->decref_src, - out_flags) != NPY_SUCCEED) { + if (PyArray_GetClearFunction(aligned, + src_dtype->elsize, src_dtype, + &data->decref_src, out_flags) != NPY_SUCCEED) { NPY_AUXDATA_FREE((NpyAuxData *)data); return NPY_FAIL; } @@ -1959,12 +1957,9 @@ get_subarray_broadcast_transfer_function(int aligned, /* If the dst object needs a DECREF to set it to NULL */ if (PyDataType_REFCHK(dst_dtype)) { - if (PyArray_GetDTypeTransferFunction(aligned, - dst_dtype->elsize, 0, - dst_dtype, NULL, - 1, - &data->decref_dst, - out_flags) != NPY_SUCCEED) { + if (PyArray_GetClearFunction(aligned, + dst_dtype->elsize, dst_dtype, + &data->decref_dst, out_flags) != NPY_SUCCEED) { NPY_AUXDATA_FREE((NpyAuxData *)data); return NPY_FAIL; } @@ -2169,6 +2164,7 @@ typedef struct { typedef struct { NpyAuxData base; npy_intp field_count; + NPY_traverse_info decref_func; _single_field_transfer fields[]; } _field_transfer_data; @@ -2333,7 +2329,7 @@ get_fields_transfer_function(int NPY_UNUSED(aligned), */ if (move_references && PyDataType_REFCHK(src_dtype)) { NPY_ARRAYMETHOD_FLAGS clear_flags; - if (get_clear_function( + if (PyArray_GetClearFunction( 0, src_stride, src_dtype, &data->fields[field_count].info, &clear_flags) < 0) { @@ -2468,7 +2464,7 @@ typedef struct { /* The transfer function being wrapped (could likely be stored directly) */ NPY_cast_info wrapped; /* The src decref function if necessary */ - NPY_cast_info decref_src; + NPY_traverse_info decref_src; } _masked_wrapper_transfer_data; /* transfer data free function */ @@ -2477,7 +2473,7 @@ _masked_wrapper_transfer_data_free(NpyAuxData *data) { _masked_wrapper_transfer_data *d = (_masked_wrapper_transfer_data *)data; NPY_cast_info_xfree(&d->wrapped); - NPY_cast_info_xfree(&d->decref_src); + NPY_traverse_info_xfree(&d->decref_src); PyMem_Free(data); } @@ -2500,7 +2496,7 @@ _masked_wrapper_transfer_data_clone(NpyAuxData *data) return NULL; } if (d->decref_src.func != NULL) { - if (NPY_cast_info_copy(&newdata->decref_src, &d->decref_src) < 0) { + if (NPY_traverse_info_copy(&newdata->decref_src, &d->decref_src) < 0) { NPY_AUXDATA_FREE((NpyAuxData *)newdata); return NULL; } @@ -2527,8 +2523,8 @@ _strided_masked_wrapper_clear_function( /* Skip masked values, still calling decref for move_references */ mask = (npy_bool*)npy_memchr((char *)mask, 0, mask_stride, N, &subloopsize, 1); - if (d->decref_src.func(&d->decref_src.context, - &src, &subloopsize, &src_stride, d->decref_src.auxdata) < 0) { + if (d->decref_src.func(NULL, d->decref_src.descr, + src, subloopsize, src_stride, d->decref_src.auxdata) < 0) { return -1; } dst += subloopsize * dst_stride; @@ -2608,49 +2604,6 @@ _cast_no_op( /* - * Get a strided loop used for clearing. This looks up the `clearimpl` - * slot on the DType. - * The function will error when `clearimpl` is not defined. Note that - * old-style user-dtypes use the "void" version (m) - */ -static int -get_clear_function( - int aligned, npy_intp stride, - PyArray_Descr *dtype, NPY_cast_info *cast_info, - NPY_ARRAYMETHOD_FLAGS *flags) -{ - NPY_cast_info_init(cast_info); - - PyArrayMethodObject *clearimpl = NPY_DT_SLOTS(NPY_DTYPE(dtype))->clearimpl; - if (clearimpl == NULL) { - PyErr_Format(PyExc_RuntimeError, - "Internal error, tried to fetch decref function for the " - "unsupported DType '%S'.", dtype); - return -1; - } - - /* Make sure all important fields are either set or cleared */ - Py_INCREF(dtype); - cast_info->descriptors[0] = dtype; - cast_info->descriptors[1] = NULL; - Py_INCREF(clearimpl); - cast_info->context.method = clearimpl; - cast_info->context.caller = NULL; - - if (clearimpl->get_strided_loop( - &cast_info->context, aligned, 0, &stride, - &cast_info->func, &cast_info->auxdata, flags) < 0) { - Py_CLEAR(cast_info->descriptors[0]); - Py_CLEAR(cast_info->context.method); - NPY_cast_info_xfree(cast_info); - return -1; - } - - return 0; -} - - -/* * ********************* Generalized Multistep Cast ************************ * * New general purpose multiple step cast function when resolve descriptors @@ -2935,8 +2888,7 @@ _clear_cast_info_after_get_loop_failure(NPY_cast_info *cast_info) /* - * Helper for PyArray_GetDTypeTransferFunction, which fetches a single - * transfer function from the each casting implementation (ArrayMethod). + * Fetches a transfer function from the casting implementation (ArrayMethod). * May set the transfer function to NULL when the cast can be achieved using * a view. * TODO: Expand the view functionality for general offsets, not just 0: @@ -2958,14 +2910,16 @@ _clear_cast_info_after_get_loop_failure(NPY_cast_info *cast_info) * * Returns -1 on failure, 0 on success */ -static int -define_cast_for_descrs( +NPY_NO_EXPORT int +PyArray_GetDTypeTransferFunction( int aligned, npy_intp src_stride, npy_intp dst_stride, PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, int move_references, NPY_cast_info *cast_info, NPY_ARRAYMETHOD_FLAGS *out_flags) { + assert(dst_dtype != NULL); /* Was previously used for decref */ + /* Storage for all cast info in case multi-step casting is necessary */ _multistep_castdata castdata; /* Initialize funcs to NULL to simplify cleanup on error. */ @@ -3117,46 +3071,6 @@ define_cast_for_descrs( } -NPY_NO_EXPORT int -PyArray_GetDTypeTransferFunction(int aligned, - npy_intp src_stride, npy_intp dst_stride, - PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, - int move_references, - NPY_cast_info *cast_info, - NPY_ARRAYMETHOD_FLAGS *out_flags) -{ - assert(src_dtype != NULL); - - /* - * If one of the dtypes is NULL, we give back either a src decref - * function or a dst setzero function - * - * TODO: Eventually, we may wish to support user dtype with references - * (including and beyond bare `PyObject *` this may require extending - * the ArrayMethod API and those paths should likely be split out - * from this function.) - */ - if (dst_dtype == NULL) { - assert(move_references); - int res = get_clear_function( - aligned, src_dtype->elsize, src_dtype, cast_info, out_flags); - if (res < 0) { - return NPY_FAIL; - } - return NPY_SUCCEED; - } - - if (define_cast_for_descrs(aligned, - src_stride, dst_stride, - src_dtype, dst_dtype, move_references, - cast_info, out_flags) < 0) { - return NPY_FAIL; - } - - return NPY_SUCCEED; -} - - /* * Internal wrapping of casts that have to be performed in a "single" * function (i.e. not by the generic multi-step-cast), but rely on it @@ -3404,7 +3318,7 @@ PyArray_GetMaskedDTypeTransferFunction(int aligned, /* If the src object will need a DECREF, get a function to handle that */ if (move_references && PyDataType_REFCHK(src_dtype)) { NPY_ARRAYMETHOD_FLAGS clear_flags; - if (get_clear_function( + if (PyArray_GetClearFunction( aligned, src_stride, src_dtype, &data->decref_src, &clear_flags) < 0) { NPY_AUXDATA_FREE((NpyAuxData *)data); @@ -3415,7 +3329,7 @@ PyArray_GetMaskedDTypeTransferFunction(int aligned, &_strided_masked_wrapper_clear_function; } else { - NPY_cast_info_init(&data->decref_src); + NPY_traverse_info_init(&data->decref_src); cast_info->func = (PyArrayMethod_StridedLoop *) &_strided_masked_wrapper_transfer_function; } diff --git a/numpy/core/src/multiarray/dtype_traversal.c b/numpy/core/src/multiarray/dtype_traversal.c index 52feb55fe..dec7a4b7a 100644 --- a/numpy/core/src/multiarray/dtype_traversal.c +++ b/numpy/core/src/multiarray/dtype_traversal.c @@ -14,10 +14,65 @@ #include "dtypemeta.h" #include "dtype_traversal.h" +#include "pyerrors.h" +#include <stdint.h> +/* Same as in dtype_transfer.c */ +#define NPY_LOWLEVEL_BUFFER_BLOCKSIZE 128 -/****************** Python Object clear ***********************/ +/* + * Generic Clear function helpers: + */ + +static int +get_clear_function( + void *traverse_context, PyArray_Descr *dtype, int aligned, + npy_intp stride, NPY_traverse_info *clear_info, + NPY_ARRAYMETHOD_FLAGS *flags) +{ + NPY_traverse_info_init(clear_info); + + get_simple_loop_function *get_clear = NPY_DT_SLOTS(NPY_DTYPE(dtype))->get_clear_loop; + if (get_clear == NULL) { + PyErr_Format(PyExc_RuntimeError, + "Internal error, tried to fetch decref/clear function for the " + "unsupported DType '%S'.", dtype); + return -1; + } + + if (get_clear(traverse_context, dtype, aligned, stride, + &clear_info->func, &clear_info->auxdata, flags) < 0) { + /* callee should clean up, but make sure outside debug mode */ + assert(clear_info->func == NULL); + clear_info->func = NULL; + return -1; + } + Py_INCREF(dtype); + clear_info->descr = dtype; + + return 0; +} +/* + * Helper to set up a strided loop used for clearing. + * The function will error when called on a dtype which does not have + * references (and thus the get_clear_loop slot NULL). + * Note that old-style user-dtypes use the "void" version. + * + * NOTE: This function may have a use for a `traverse_context` at some point + * but right now, it is always NULL and only exists to allow adding it + * in the future without changing the strided-loop signature. + */ +NPY_NO_EXPORT int +PyArray_GetClearFunction( + int aligned, npy_intp stride, PyArray_Descr *dtype, + NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags) +{ + return get_clear_function(NULL, dtype, aligned, stride, clear_info, flags); +} + + +/****************** Python Object clear ***********************/ static int clear_object_strided_loop( @@ -82,7 +137,7 @@ fields_clear_data_free(NpyAuxData *data) fields_clear_data *d = (fields_clear_data *)data; for (npy_intp i = 0; i < d->field_count; ++i) { - NPY_traverse_info_free(&d->fields[i].info); + NPY_traverse_info_xfree(&d->fields[i].info); } PyMem_Free(d); } @@ -125,16 +180,55 @@ fields_clear_data_clone(NpyAuxData *data) static int +traverse_fields_function( + void *traverse_context, PyArray_Descr *NPY_UNUSED(descr), + char *data, npy_intp N, npy_intp stride, + NpyAuxData *auxdata) +{ + fields_clear_data *d = (fields_clear_data *)auxdata; + npy_intp i, field_count = d->field_count; + + /* Do the traversing a block at a time for better memory caching */ + const npy_intp blocksize = NPY_LOWLEVEL_BUFFER_BLOCKSIZE; + + for (;;) { + if (N > blocksize) { + for (i = 0; i < field_count; ++i) { + single_field_clear_data field = d->fields[i]; + if (field.info.func(traverse_context, + field.info.descr, data + field.src_offset, + blocksize, stride, field.info.auxdata) < 0) { + return -1; + } + } + N -= NPY_LOWLEVEL_BUFFER_BLOCKSIZE; + data += NPY_LOWLEVEL_BUFFER_BLOCKSIZE * stride; + } + else { + for (i = 0; i < field_count; ++i) { + single_field_clear_data field = d->fields[i]; + if (field.info.func(traverse_context, + field.info.descr, data + field.src_offset, + blocksize, stride, field.info.auxdata) < 0) { + return -1; + } + } + return 0; + } + } +} + + +static int get_clear_fields_transfer_function( - void *NPY_UNUSED(traverse_context), PyArray_Descr *dtype, - npy_intp stride, simple_loop_function **out_func, - NpyAuxData **out_transferdata, NPY_ARRAYMETHOD_FLAGS *flags) + void *traverse_context, int NPY_UNUSED(aligned), + PyArray_Descr *dtype, npy_intp stride, simple_loop_function **out_func, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) { PyObject *names, *key, *tup, *title; PyArray_Descr *fld_dtype; npy_int i, structsize; Py_ssize_t field_count; - int src_offset; names = dtype->names; field_count = PyTuple_GET_SIZE(dtype->names); @@ -154,35 +248,131 @@ get_clear_fields_transfer_function( single_field_clear_data *field = data->fields; for (i = 0; i < field_count; ++i) { + int offset; + key = PyTuple_GET_ITEM(names, i); tup = PyDict_GetItem(dtype->fields, key); - if (!PyArg_ParseTuple(tup, "Oi|O", &fld_dtype, - &offset, &title)) { + if (!PyArg_ParseTuple(tup, "Oi|O", &fld_dtype, &offset, &title)) { NPY_AUXDATA_FREE((NpyAuxData *)data); return NPY_FAIL; } if (PyDataType_REFCHK(fld_dtype)) { NPY_ARRAYMETHOD_FLAGS clear_flags; if (get_clear_function( - 0, stride, fld_dtype, - &field->info, &clear_flags) < 0) { + traverse_context, fld_dtype, 0, + stride, &field->info, &clear_flags) < 0) { NPY_AUXDATA_FREE((NpyAuxData *)data); return NPY_FAIL; } *flags = PyArrayMethod_COMBINED_FLAGS(*flags, clear_flags); - field->src_offset = src_offset; + field->src_offset = offset; data->field_count++; field++; } } - *out_stransfer = &_strided_to_strided_field_transfer; - *out_transferdata = (NpyAuxData *)data; + *out_func = &traverse_fields_function; + *out_auxdata = (NpyAuxData *)data; return NPY_SUCCEED; } + +typedef struct { + NpyAuxData base; + npy_intp count; + NPY_traverse_info info; +} subarray_clear_data; + + +/* transfer data free function */ +static void +subarray_clear_data_free(NpyAuxData *data) +{ + subarray_clear_data *d = (subarray_clear_data *)data; + + NPY_traverse_info_xfree(&d->info); + PyMem_Free(d); +} + + +/* transfer data copy function */ +static NpyAuxData * +subarray_clear_data_clone(NpyAuxData *data) +{ + subarray_clear_data *d = (subarray_clear_data *)data; + + /* Allocate the data and populate it */ + subarray_clear_data *newdata = PyMem_Malloc(sizeof(subarray_clear_data)); + if (newdata == NULL) { + return NULL; + } + newdata->count = d->count; + + if (NPY_traverse_info_copy(&newdata->info, &d->info) < 0) { + PyMem_Free(newdata); + return NULL; + } + + return (NpyAuxData *)newdata; +} + + +static int +traverse_subarray_func( + void *traverse_context, PyArray_Descr *NPY_UNUSED(descr), + char *data, npy_intp N, npy_intp stride, + NpyAuxData *auxdata) +{ + subarray_clear_data *subarr_data = (subarray_clear_data *)auxdata; + + simple_loop_function *func = subarr_data->info.func; + PyArray_Descr *sub_descr = subarr_data->info.descr; + npy_intp sub_N = subarr_data->count; + NpyAuxData *sub_auxdata = subarr_data->info.auxdata; + npy_intp sub_stride = sub_descr->elsize; + + while (N--) { + if (func(traverse_context, sub_descr, data, + sub_N, sub_stride, sub_auxdata) < 0) { + return -1; + } + data += stride; + } + return 0; +} + + +static int +get_subarray_clear_func( + void *traverse_context, int aligned, PyArray_Descr *dtype, + npy_intp size, npy_intp stride, simple_loop_function **out_func, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) +{ + subarray_clear_data *auxdata = PyMem_Malloc(sizeof(subarray_clear_data)); + if (auxdata == NULL) { + PyErr_NoMemory(); + return -1; + } + + auxdata->count = size; + auxdata->base.free = &subarray_clear_data_free; + auxdata->base.clone = &subarray_clear_data_clone; + + if (get_clear_function( + traverse_context, dtype, aligned, + dtype->elsize, &auxdata->info, flags) < 0) { + PyMem_Free(auxdata); + return -1; + } + *out_func = &traverse_subarray_func; + *out_auxdata = (NpyAuxData *)auxdata; + + return 0; +} + + static int clear_no_op( void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr), @@ -195,36 +385,37 @@ clear_no_op( NPY_NO_EXPORT int npy_get_clear_void_and_legacy_user_dtype_loop( - void *NPY_UNUSED(traverse_context), PyArray_Descr *dtype, + void *traverse_context, int aligned, PyArray_Descr *dtype, npy_intp stride, simple_loop_function **out_func, - NpyAuxData **out_transferdata, NPY_ARRAYMETHOD_FLAGS *flags) + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) { - /* If there are no references, it's a nop (path should not be hit?) */ + /* + * If there are no references, it's a nop. This path should not be hit + * but structured dtypes are tricky when a dtype which included references + * was sliced to not include any. + */ if (!PyDataType_REFCHK(dtype)) { - *out_loop = &clear_no_op; - *out_transferdata = NULL; + *out_func = &clear_no_op; + *out_auxdata = NULL; assert(0); return 0; } if (PyDataType_HASSUBARRAY(dtype)) { - PyArray_Dims src_shape = {NULL, -1}; - npy_intp src_size; + PyArray_Dims shape = {NULL, -1}; + npy_intp size; - if (!(PyArray_IntpConverter(dtype->subarray->shape, - &src_shape))) { + if (!(PyArray_IntpConverter(dtype->subarray->shape, &shape))) { PyErr_SetString(PyExc_ValueError, "invalid subarray shape"); return -1; } - src_size = PyArray_MultiplyList(src_shape.ptr, src_shape.len); - npy_free_cache_dim_obj(src_shape); - - if (get_n_to_n_transfer_function(aligned, - stride, 0, - dtype->subarray->base, NULL, 1, src_size, - out_loop, out_transferdata, - flags) != NPY_SUCCEED) { + size = PyArray_MultiplyList(shape.ptr, shape.len); + npy_free_cache_dim_obj(shape); + + if (get_subarray_clear_func( + traverse_context, aligned, dtype->subarray->base, size, stride, + out_func, out_auxdata, flags) != NPY_SUCCEED) { return -1; } @@ -233,8 +424,8 @@ npy_get_clear_void_and_legacy_user_dtype_loop( /* If there are fields, need to do each field */ else if (PyDataType_HASFIELDS(dtype)) { if (get_clear_fields_transfer_function( - aligned, stride, dtype, - out_loop, out_transferdata, flags) < 0) { + traverse_context, aligned, dtype, stride, + out_func, out_auxdata, flags) < 0) { return -1; } return 0; diff --git a/numpy/core/src/multiarray/dtype_traversal.h b/numpy/core/src/multiarray/dtype_traversal.h index 19c3a6de1..919c49eb7 100644 --- a/numpy/core/src/multiarray/dtype_traversal.h +++ b/numpy/core/src/multiarray/dtype_traversal.h @@ -33,7 +33,8 @@ typedef int (simple_loop_function)( /* Simplified get_loop function specific to dtype traversal */ typedef int (get_simple_loop_function)( - void *traverse_context, int aligned, npy_intp fixed_stride, + void *traverse_context, PyArray_Descr *descr, + int aligned, npy_intp fixed_stride, simple_loop_function **out_loop, NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags); @@ -64,8 +65,18 @@ typedef struct { static inline void -NPY_traverse_info_free(NPY_traverse_info *traverse_info) +NPY_traverse_info_init(NPY_traverse_info *cast_info) { + cast_info->func = NULL; /* mark as uninitialized. */ +} + + +static inline void +NPY_traverse_info_xfree(NPY_traverse_info *traverse_info) +{ + if (traverse_info->func == NULL) { + return; + } traverse_info->func = NULL; NPY_AUXDATA_FREE(traverse_info->auxdata); Py_DECREF(traverse_info->descr); @@ -91,4 +102,10 @@ NPY_traverse_info_copy( } +NPY_NO_EXPORT int +PyArray_GetClearFunction( + int aligned, npy_intp stride, PyArray_Descr *dtype, + NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags); + + #endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */
\ No newline at end of file diff --git a/numpy/core/src/multiarray/refcount.c b/numpy/core/src/multiarray/refcount.c index daa5bb289..bb2fb982f 100644 --- a/numpy/core/src/multiarray/refcount.c +++ b/numpy/core/src/multiarray/refcount.c @@ -3,7 +3,7 @@ * section in the numpy reference for C-API. */ #include "array_method.h" -#include "dtype_transfer.h" +#include "dtype_traversal.h" #include "lowlevel_strided_loops.h" #include "pyerrors.h" #define NPY_NO_DEPRECATED_API NPY_API_VERSION @@ -43,16 +43,16 @@ PyArray_ClearData( return 0; } - NPY_cast_info cast_info; + NPY_traverse_info clear_info; NPY_ARRAYMETHOD_FLAGS flags; - if (PyArray_GetDTypeTransferFunction( - aligned, stride, 0, descr, NULL, 1, &cast_info, &flags) < 0) { + if (PyArray_GetClearFunction( + aligned, stride, descr, &clear_info, &flags) < 0) { return -1; } - int res = cast_info.func( - &cast_info.context, &data, &size, &stride, cast_info.auxdata); - NPY_cast_info_xfree(&cast_info); + int res = clear_info.func( + NULL, clear_info.descr, data, size, stride, clear_info.auxdata); + NPY_traverse_info_xfree(&clear_info); return res; } |