diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-04-20 12:44:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-20 12:44:07 +0200 |
commit | e20f11036bb7ce9f8de91eb4240e49ea4e41ef17 (patch) | |
tree | 7404cf58ed0d4c176ff94172c65bd305ce9bdfba /numpy/core/src | |
parent | d0d0698184de6870fbd2f3aec5ed801e424b3598 (diff) | |
parent | 6a7bf10722318143d4b016e8ba0a44e426a535dd (diff) | |
download | numpy-e20f11036bb7ce9f8de91eb4240e49ea4e41ef17.tar.gz |
Merge pull request #23591 from ngoldbaum/refactor-zero
ENH: refactor zero-filling and expose dtype API slot for it
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/multiarray/common.c | 18 | ||||
-rw-r--r-- | numpy/core/src/multiarray/common.h | 3 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 18 | ||||
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 61 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtype_traversal.c | 74 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtype_traversal.h | 21 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 9 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.h | 19 | ||||
-rw-r--r-- | numpy/core/src/multiarray/getset.c | 19 | ||||
-rw-r--r-- | numpy/core/src/multiarray/refcount.c | 7 | ||||
-rw-r--r-- | numpy/core/src/multiarray/refcount.h | 3 |
11 files changed, 177 insertions, 75 deletions
diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index da8d23a26..001d299c7 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -128,24 +128,6 @@ PyArray_DTypeFromObject(PyObject *obj, int maxdims, PyArray_Descr **out_dtype) } -NPY_NO_EXPORT int -_zerofill(PyArrayObject *ret) -{ - if (PyDataType_REFCHK(PyArray_DESCR(ret))) { - PyObject *zero = PyLong_FromLong(0); - PyArray_FillObjectArray(ret, zero); - Py_DECREF(zero); - if (PyErr_Occurred()) { - return -1; - } - } - else { - npy_intp n = PyArray_NBYTES(ret); - memset(PyArray_DATA(ret), 0, n); - } - return 0; -} - NPY_NO_EXPORT npy_bool _IsWriteable(PyArrayObject *ap) { diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h index 4e067b22c..127c6250d 100644 --- a/numpy/core/src/multiarray/common.h +++ b/numpy/core/src/multiarray/common.h @@ -55,9 +55,6 @@ PyArray_DTypeFromObject(PyObject *obj, int maxdims, NPY_NO_EXPORT PyArray_Descr * _array_find_python_scalar_type(PyObject *op); -NPY_NO_EXPORT int -_zerofill(PyArrayObject *ret); - NPY_NO_EXPORT npy_bool _IsWriteable(PyArrayObject *ap); diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 53db5c577..de1cd075d 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -3913,21 +3913,6 @@ PyArray_InitializeObjectToObjectCast(void) } -static int -PyArray_SetClearFunctions(void) -{ - PyArray_DTypeMeta *Object = PyArray_DTypeFromTypeNum(NPY_OBJECT); - NPY_DT_SLOTS(Object)->get_clear_loop = &npy_get_clear_object_strided_loop; - Py_DECREF(Object); /* use borrowed */ - - PyArray_DTypeMeta *Void = PyArray_DTypeFromTypeNum(NPY_VOID); - NPY_DT_SLOTS(Void)->get_clear_loop = &npy_get_clear_void_and_legacy_user_dtype_loop; - Py_DECREF(Void); /* use borrowed */ - return 0; -} - - - NPY_NO_EXPORT int PyArray_InitializeCasts() { @@ -3947,8 +3932,5 @@ PyArray_InitializeCasts() if (PyArray_InitializeDatetimeCasts() < 0) { return -1; } - if (PyArray_SetClearFunctions() < 0) { - return -1; - } return 0; } diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index a6335c783..79a1905a7 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -715,6 +715,11 @@ PyArray_NewFromDescr_int( fa->base = (PyObject *)NULL; fa->weakreflist = (PyObject *)NULL; + /* needed for zero-filling logic below, defined and initialized up here + so cleanup logic can go in the fail block */ + NPY_traverse_info fill_zero_info; + NPY_traverse_info_init(&fill_zero_info); + if (nd > 0) { fa->dimensions = npy_alloc_cache_dim(2 * nd); if (fa->dimensions == NULL) { @@ -784,6 +789,31 @@ PyArray_NewFromDescr_int( if (data == NULL) { + /* float errors do not matter and we do not release GIL */ + NPY_ARRAYMETHOD_FLAGS zero_flags; + get_traverse_loop_function *get_fill_zero_loop = + NPY_DT_SLOTS(NPY_DTYPE(descr))->get_fill_zero_loop; + if (get_fill_zero_loop != NULL) { + if (get_fill_zero_loop( + NULL, descr, 1, descr->elsize, &(fill_zero_info.func), + &(fill_zero_info.auxdata), &zero_flags) < 0) { + goto fail; + } + } + + /* + * We always want a zero-filled array allocated with calloc if + * NPY_NEEDS_INIT is set on the dtype, for safety. We also want a + * zero-filled array if zeroed is set and the zero-filling loop isn't + * defined, for better performance. + * + * If the zero-filling loop is defined and zeroed is set, allocate + * with malloc and let the zero-filling loop fill the array buffer + * with valid zero values for the dtype. + */ + int use_calloc = (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT) || + (zeroed && (fill_zero_info.func == NULL))); + /* Store the handler in case the default is modified */ fa->mem_handler = PyDataMem_GetHandler(); if (fa->mem_handler == NULL) { @@ -801,11 +831,8 @@ PyArray_NewFromDescr_int( fa->strides[i] = 0; } } - /* - * It is bad to have uninitialized OBJECT pointers - * which could also be sub-fields of a VOID array - */ - if (zeroed || PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) { + + if (use_calloc) { data = PyDataMem_UserNEW_ZEROED(nbytes, 1, fa->mem_handler); } else { @@ -816,6 +843,18 @@ PyArray_NewFromDescr_int( goto fail; } + /* + * If the array needs special dtype-specific zero-filling logic, do that + */ + if (NPY_UNLIKELY(zeroed && (fill_zero_info.func != NULL))) { + npy_intp size = PyArray_MultiplyList(fa->dimensions, fa->nd); + if (fill_zero_info.func( + NULL, descr, data, size, descr->elsize, + fill_zero_info.auxdata) < 0) { + goto fail; + } + } + fa->flags |= NPY_ARRAY_OWNDATA; } else { @@ -910,9 +949,11 @@ PyArray_NewFromDescr_int( } } } + NPY_traverse_info_xfree(&fill_zero_info); return (PyObject *)fa; fail: + NPY_traverse_info_xfree(&fill_zero_info); Py_XDECREF(fa->mem_handler); Py_DECREF(fa); return NULL; @@ -3017,17 +3058,7 @@ PyArray_Zeros(int nd, npy_intp const *dims, PyArray_Descr *type, int is_f_order) return NULL; } - /* handle objects */ - if (PyDataType_REFCHK(PyArray_DESCR(ret))) { - if (_zerofill(ret) < 0) { - Py_DECREF(ret); - return NULL; - } - } - - return (PyObject *)ret; - } /*NUMPY_API diff --git a/numpy/core/src/multiarray/dtype_traversal.c b/numpy/core/src/multiarray/dtype_traversal.c index cefa7d6e1..769c2e015 100644 --- a/numpy/core/src/multiarray/dtype_traversal.c +++ b/numpy/core/src/multiarray/dtype_traversal.c @@ -24,7 +24,7 @@ #include "alloc.h" #include "array_method.h" #include "dtypemeta.h" - +#include "refcount.h" #include "dtype_traversal.h" @@ -124,6 +124,39 @@ npy_get_clear_object_strided_loop( } +/**************** Python Object zero fill *********************/ + +static int +fill_zero_object_strided_loop( + void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr), + char *data, npy_intp size, npy_intp stride, + NpyAuxData *NPY_UNUSED(auxdata)) +{ + PyObject *zero = PyLong_FromLong(0); + while (size--) { + Py_INCREF(zero); + // assumes `data` doesn't have a pre-existing object inside it + memcpy(data, &zero, sizeof(zero)); + data += stride; + } + Py_DECREF(zero); + return 0; +} + +NPY_NO_EXPORT int +npy_object_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), + PyArray_Descr *NPY_UNUSED(descr), + int NPY_UNUSED(aligned), + npy_intp NPY_UNUSED(fixed_stride), + traverse_loop_function **out_loop, + NpyAuxData **NPY_UNUSED(out_auxdata), + NPY_ARRAYMETHOD_FLAGS *flags) +{ + *flags = NPY_METH_REQUIRES_PYAPI | NPY_METH_NO_FLOATINGPOINT_ERRORS; + *out_loop = &fill_zero_object_strided_loop; + return 0; +} + /**************** Structured DType clear funcationality ***************/ /* @@ -408,7 +441,6 @@ clear_no_op( return 0; } - NPY_NO_EXPORT int npy_get_clear_void_and_legacy_user_dtype_loop( void *traverse_context, PyArray_Descr *dtype, int aligned, @@ -472,3 +504,41 @@ npy_get_clear_void_and_legacy_user_dtype_loop( dtype); return -1; } + +/**************** Structured DType zero fill ***************/ + +static int +fill_zero_void_with_objects_strided_loop( + void *NPY_UNUSED(traverse_context), PyArray_Descr *descr, + char *data, npy_intp size, npy_intp stride, + NpyAuxData *NPY_UNUSED(auxdata)) +{ + PyObject *zero = PyLong_FromLong(0); + while (size--) { + _fillobject(data, zero, descr); + data += stride; + } + Py_DECREF(zero); + return 0; +} + + +NPY_NO_EXPORT int +npy_void_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), + PyArray_Descr *descr, + int NPY_UNUSED(aligned), + npy_intp NPY_UNUSED(fixed_stride), + traverse_loop_function **out_loop, + NpyAuxData **NPY_UNUSED(out_auxdata), + NPY_ARRAYMETHOD_FLAGS *flags) +{ + *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + if (PyDataType_REFCHK(descr)) { + *flags |= NPY_METH_REQUIRES_PYAPI; + *out_loop = &fill_zero_void_with_objects_strided_loop; + } + else { + *out_loop = NULL; + } + return 0; +} diff --git a/numpy/core/src/multiarray/dtype_traversal.h b/numpy/core/src/multiarray/dtype_traversal.h index fd060a0f0..a9c185382 100644 --- a/numpy/core/src/multiarray/dtype_traversal.h +++ b/numpy/core/src/multiarray/dtype_traversal.h @@ -19,6 +19,22 @@ npy_get_clear_void_and_legacy_user_dtype_loop( traverse_loop_function **out_loop, NpyAuxData **out_traversedata, NPY_ARRAYMETHOD_FLAGS *flags); +/* NumPy DType zero-filling implementations */ + +NPY_NO_EXPORT int +npy_object_get_fill_zero_loop( + void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr), + int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride), + traverse_loop_function **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata), + NPY_ARRAYMETHOD_FLAGS *flags); + +NPY_NO_EXPORT int +npy_void_get_fill_zero_loop( + void *NPY_UNUSED(traverse_context), PyArray_Descr *descr, + int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride), + traverse_loop_function **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata), + NPY_ARRAYMETHOD_FLAGS *flags); + /* Helper to deal with calling or nesting simple strided loops */ @@ -34,6 +50,7 @@ NPY_traverse_info_init(NPY_traverse_info *cast_info) { cast_info->func = NULL; /* mark as uninitialized. */ cast_info->auxdata = NULL; /* allow leaving auxdata untouched */ + cast_info->descr = NULL; /* mark as uninitialized. */ } @@ -45,7 +62,7 @@ NPY_traverse_info_xfree(NPY_traverse_info *traverse_info) } traverse_info->func = NULL; NPY_AUXDATA_FREE(traverse_info->auxdata); - Py_DECREF(traverse_info->descr); + Py_XDECREF(traverse_info->descr); } @@ -79,4 +96,4 @@ PyArray_GetClearFunction( NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags); -#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */
\ No newline at end of file +#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */ diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index f268ba2cb..f8c1b6617 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -20,6 +20,8 @@ #include "usertypes.h" #include "conversion_utils.h" #include "templ_common.h" +#include "refcount.h" +#include "dtype_traversal.h" #include <assert.h> @@ -524,6 +526,7 @@ void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2) return NULL; } + NPY_NO_EXPORT int python_builtins_are_known_scalar_types( PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype) @@ -855,6 +858,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) dt_slots->common_dtype = default_builtin_common_dtype; dt_slots->common_instance = NULL; dt_slots->ensure_canonical = ensure_native_byteorder; + dt_slots->get_fill_zero_loop = NULL; if (PyTypeNum_ISSIGNED(dtype_class->type_num)) { /* Convert our scalars (raise on too large unsigned and NaN, etc.) */ @@ -866,6 +870,8 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) } else if (descr->type_num == NPY_OBJECT) { dt_slots->common_dtype = object_common_dtype; + dt_slots->get_fill_zero_loop = npy_object_get_fill_zero_loop; + dt_slots->get_clear_loop = npy_get_clear_object_strided_loop; } else if (PyTypeNum_ISDATETIME(descr->type_num)) { /* Datetimes are flexible, but were not considered previously */ @@ -887,6 +893,9 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) void_discover_descr_from_pyobject); dt_slots->common_instance = void_common_instance; dt_slots->ensure_canonical = void_ensure_canonical; + dt_slots->get_fill_zero_loop = npy_void_get_fill_zero_loop; + dt_slots->get_clear_loop = + npy_get_clear_void_and_legacy_user_dtype_loop; } else { dt_slots->default_descr = string_and_unicode_default_descr; diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h index 3b4dbad24..6dbfd1549 100644 --- a/numpy/core/src/multiarray/dtypemeta.h +++ b/numpy/core/src/multiarray/dtypemeta.h @@ -42,6 +42,23 @@ typedef struct { */ get_traverse_loop_function *get_clear_loop; /* + Either NULL or a function that sets a function pointer to a traversal + loop that fills an array with zero values appropriate for the dtype. If + get_fill_zero_loop is undefined or the function pointer set by it is + NULL, the array buffer is allocated with calloc. If this function is + defined and it sets a non-NULL function pointer, the array buffer is + allocated with malloc and the zero-filling loop function pointer is + called to fill the buffer. For the best performance, avoid using this + function if a zero-filled array buffer allocated with calloc makes sense + for the dtype. + + Note that this is currently used only for zero-filling a newly allocated + array. While it can be used to zero-fill an already-filled buffer, that + will not work correctly for arrays holding references. If you need to do + that, clear the array first. + */ + get_traverse_loop_function *get_fill_zero_loop; + /* * The casting implementation (ArrayMethod) to convert between two * instances of this DType, stored explicitly for fast access: */ @@ -63,7 +80,7 @@ typedef struct { // This must be updated if new slots before within_dtype_castingimpl // are added -#define NPY_NUM_DTYPE_SLOTS 9 +#define NPY_NUM_DTYPE_SLOTS 10 #define NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS 22 #define NPY_DT_MAX_ARRFUNCS_SLOT \ NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS + _NPY_DT_ARRFUNCS_OFFSET diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c index ab35548ed..d019acbb5 100644 --- a/numpy/core/src/multiarray/getset.c +++ b/numpy/core/src/multiarray/getset.c @@ -797,20 +797,17 @@ array_imag_get(PyArrayObject *self, void *NPY_UNUSED(ignored)) } else { Py_INCREF(PyArray_DESCR(self)); - ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self), - PyArray_DESCR(self), - PyArray_NDIM(self), - PyArray_DIMS(self), - NULL, NULL, - PyArray_ISFORTRAN(self), - (PyObject *)self); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + Py_TYPE(self), + PyArray_DESCR(self), + PyArray_NDIM(self), + PyArray_DIMS(self), + NULL, NULL, + PyArray_ISFORTRAN(self), + (PyObject *)self, NULL, 1, 0); if (ret == NULL) { return NULL; } - if (_zerofill(ret) < 0) { - Py_DECREF(ret); - return NULL; - } PyArray_CLEARFLAGS(ret, NPY_ARRAY_WRITEABLE); } return (PyObject *) ret; diff --git a/numpy/core/src/multiarray/refcount.c b/numpy/core/src/multiarray/refcount.c index 20527f7af..d200957c3 100644 --- a/numpy/core/src/multiarray/refcount.c +++ b/numpy/core/src/multiarray/refcount.c @@ -17,15 +17,12 @@ #include "numpy/arrayscalars.h" #include "iterators.h" #include "dtypemeta.h" +#include "refcount.h" #include "npy_config.h" #include "npy_pycompat.h" -static void -_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype); - - /* * Helper function to clear a strided memory (normally or always contiguous) * from all Python (or other) references. The function does nothing if the @@ -395,7 +392,7 @@ PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj) } } -static void +NPY_NO_EXPORT void _fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype) { if (!PyDataType_FLAGCHK(dtype, NPY_ITEM_REFCOUNT)) { diff --git a/numpy/core/src/multiarray/refcount.h b/numpy/core/src/multiarray/refcount.h index 16d34e292..7f39b9ca4 100644 --- a/numpy/core/src/multiarray/refcount.h +++ b/numpy/core/src/multiarray/refcount.h @@ -24,4 +24,7 @@ PyArray_XDECREF(PyArrayObject *mp); NPY_NO_EXPORT void PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj); +NPY_NO_EXPORT void +_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype); + #endif /* NUMPY_CORE_SRC_MULTIARRAY_REFCOUNT_H_ */ |