summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-04-20 12:44:07 +0200
committerGitHub <noreply@github.com>2023-04-20 12:44:07 +0200
commite20f11036bb7ce9f8de91eb4240e49ea4e41ef17 (patch)
tree7404cf58ed0d4c176ff94172c65bd305ce9bdfba /numpy/core/src
parentd0d0698184de6870fbd2f3aec5ed801e424b3598 (diff)
parent6a7bf10722318143d4b016e8ba0a44e426a535dd (diff)
downloadnumpy-e20f11036bb7ce9f8de91eb4240e49ea4e41ef17.tar.gz
Merge pull request #23591 from ngoldbaum/refactor-zero
ENH: refactor zero-filling and expose dtype API slot for it
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/multiarray/common.c18
-rw-r--r--numpy/core/src/multiarray/common.h3
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c18
-rw-r--r--numpy/core/src/multiarray/ctors.c61
-rw-r--r--numpy/core/src/multiarray/dtype_traversal.c74
-rw-r--r--numpy/core/src/multiarray/dtype_traversal.h21
-rw-r--r--numpy/core/src/multiarray/dtypemeta.c9
-rw-r--r--numpy/core/src/multiarray/dtypemeta.h19
-rw-r--r--numpy/core/src/multiarray/getset.c19
-rw-r--r--numpy/core/src/multiarray/refcount.c7
-rw-r--r--numpy/core/src/multiarray/refcount.h3
11 files changed, 177 insertions, 75 deletions
diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c
index da8d23a26..001d299c7 100644
--- a/numpy/core/src/multiarray/common.c
+++ b/numpy/core/src/multiarray/common.c
@@ -128,24 +128,6 @@ PyArray_DTypeFromObject(PyObject *obj, int maxdims, PyArray_Descr **out_dtype)
}
-NPY_NO_EXPORT int
-_zerofill(PyArrayObject *ret)
-{
- if (PyDataType_REFCHK(PyArray_DESCR(ret))) {
- PyObject *zero = PyLong_FromLong(0);
- PyArray_FillObjectArray(ret, zero);
- Py_DECREF(zero);
- if (PyErr_Occurred()) {
- return -1;
- }
- }
- else {
- npy_intp n = PyArray_NBYTES(ret);
- memset(PyArray_DATA(ret), 0, n);
- }
- return 0;
-}
-
NPY_NO_EXPORT npy_bool
_IsWriteable(PyArrayObject *ap)
{
diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h
index 4e067b22c..127c6250d 100644
--- a/numpy/core/src/multiarray/common.h
+++ b/numpy/core/src/multiarray/common.h
@@ -55,9 +55,6 @@ PyArray_DTypeFromObject(PyObject *obj, int maxdims,
NPY_NO_EXPORT PyArray_Descr *
_array_find_python_scalar_type(PyObject *op);
-NPY_NO_EXPORT int
-_zerofill(PyArrayObject *ret);
-
NPY_NO_EXPORT npy_bool
_IsWriteable(PyArrayObject *ap);
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 53db5c577..de1cd075d 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -3913,21 +3913,6 @@ PyArray_InitializeObjectToObjectCast(void)
}
-static int
-PyArray_SetClearFunctions(void)
-{
- PyArray_DTypeMeta *Object = PyArray_DTypeFromTypeNum(NPY_OBJECT);
- NPY_DT_SLOTS(Object)->get_clear_loop = &npy_get_clear_object_strided_loop;
- Py_DECREF(Object); /* use borrowed */
-
- PyArray_DTypeMeta *Void = PyArray_DTypeFromTypeNum(NPY_VOID);
- NPY_DT_SLOTS(Void)->get_clear_loop = &npy_get_clear_void_and_legacy_user_dtype_loop;
- Py_DECREF(Void); /* use borrowed */
- return 0;
-}
-
-
-
NPY_NO_EXPORT int
PyArray_InitializeCasts()
{
@@ -3947,8 +3932,5 @@ PyArray_InitializeCasts()
if (PyArray_InitializeDatetimeCasts() < 0) {
return -1;
}
- if (PyArray_SetClearFunctions() < 0) {
- return -1;
- }
return 0;
}
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index a6335c783..79a1905a7 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -715,6 +715,11 @@ PyArray_NewFromDescr_int(
fa->base = (PyObject *)NULL;
fa->weakreflist = (PyObject *)NULL;
+ /* needed for zero-filling logic below, defined and initialized up here
+ so cleanup logic can go in the fail block */
+ NPY_traverse_info fill_zero_info;
+ NPY_traverse_info_init(&fill_zero_info);
+
if (nd > 0) {
fa->dimensions = npy_alloc_cache_dim(2 * nd);
if (fa->dimensions == NULL) {
@@ -784,6 +789,31 @@ PyArray_NewFromDescr_int(
if (data == NULL) {
+ /* float errors do not matter and we do not release GIL */
+ NPY_ARRAYMETHOD_FLAGS zero_flags;
+ get_traverse_loop_function *get_fill_zero_loop =
+ NPY_DT_SLOTS(NPY_DTYPE(descr))->get_fill_zero_loop;
+ if (get_fill_zero_loop != NULL) {
+ if (get_fill_zero_loop(
+ NULL, descr, 1, descr->elsize, &(fill_zero_info.func),
+ &(fill_zero_info.auxdata), &zero_flags) < 0) {
+ goto fail;
+ }
+ }
+
+ /*
+ * We always want a zero-filled array allocated with calloc if
+ * NPY_NEEDS_INIT is set on the dtype, for safety. We also want a
+ * zero-filled array if zeroed is set and the zero-filling loop isn't
+ * defined, for better performance.
+ *
+ * If the zero-filling loop is defined and zeroed is set, allocate
+ * with malloc and let the zero-filling loop fill the array buffer
+ * with valid zero values for the dtype.
+ */
+ int use_calloc = (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT) ||
+ (zeroed && (fill_zero_info.func == NULL)));
+
/* Store the handler in case the default is modified */
fa->mem_handler = PyDataMem_GetHandler();
if (fa->mem_handler == NULL) {
@@ -801,11 +831,8 @@ PyArray_NewFromDescr_int(
fa->strides[i] = 0;
}
}
- /*
- * It is bad to have uninitialized OBJECT pointers
- * which could also be sub-fields of a VOID array
- */
- if (zeroed || PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) {
+
+ if (use_calloc) {
data = PyDataMem_UserNEW_ZEROED(nbytes, 1, fa->mem_handler);
}
else {
@@ -816,6 +843,18 @@ PyArray_NewFromDescr_int(
goto fail;
}
+ /*
+ * If the array needs special dtype-specific zero-filling logic, do that
+ */
+ if (NPY_UNLIKELY(zeroed && (fill_zero_info.func != NULL))) {
+ npy_intp size = PyArray_MultiplyList(fa->dimensions, fa->nd);
+ if (fill_zero_info.func(
+ NULL, descr, data, size, descr->elsize,
+ fill_zero_info.auxdata) < 0) {
+ goto fail;
+ }
+ }
+
fa->flags |= NPY_ARRAY_OWNDATA;
}
else {
@@ -910,9 +949,11 @@ PyArray_NewFromDescr_int(
}
}
}
+ NPY_traverse_info_xfree(&fill_zero_info);
return (PyObject *)fa;
fail:
+ NPY_traverse_info_xfree(&fill_zero_info);
Py_XDECREF(fa->mem_handler);
Py_DECREF(fa);
return NULL;
@@ -3017,17 +3058,7 @@ PyArray_Zeros(int nd, npy_intp const *dims, PyArray_Descr *type, int is_f_order)
return NULL;
}
- /* handle objects */
- if (PyDataType_REFCHK(PyArray_DESCR(ret))) {
- if (_zerofill(ret) < 0) {
- Py_DECREF(ret);
- return NULL;
- }
- }
-
-
return (PyObject *)ret;
-
}
/*NUMPY_API
diff --git a/numpy/core/src/multiarray/dtype_traversal.c b/numpy/core/src/multiarray/dtype_traversal.c
index cefa7d6e1..769c2e015 100644
--- a/numpy/core/src/multiarray/dtype_traversal.c
+++ b/numpy/core/src/multiarray/dtype_traversal.c
@@ -24,7 +24,7 @@
#include "alloc.h"
#include "array_method.h"
#include "dtypemeta.h"
-
+#include "refcount.h"
#include "dtype_traversal.h"
@@ -124,6 +124,39 @@ npy_get_clear_object_strided_loop(
}
+/**************** Python Object zero fill *********************/
+
+static int
+fill_zero_object_strided_loop(
+ void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr),
+ char *data, npy_intp size, npy_intp stride,
+ NpyAuxData *NPY_UNUSED(auxdata))
+{
+ PyObject *zero = PyLong_FromLong(0);
+ while (size--) {
+ Py_INCREF(zero);
+ // assumes `data` doesn't have a pre-existing object inside it
+ memcpy(data, &zero, sizeof(zero));
+ data += stride;
+ }
+ Py_DECREF(zero);
+ return 0;
+}
+
+NPY_NO_EXPORT int
+npy_object_get_fill_zero_loop(void *NPY_UNUSED(traverse_context),
+ PyArray_Descr *NPY_UNUSED(descr),
+ int NPY_UNUSED(aligned),
+ npy_intp NPY_UNUSED(fixed_stride),
+ traverse_loop_function **out_loop,
+ NpyAuxData **NPY_UNUSED(out_auxdata),
+ NPY_ARRAYMETHOD_FLAGS *flags)
+{
+ *flags = NPY_METH_REQUIRES_PYAPI | NPY_METH_NO_FLOATINGPOINT_ERRORS;
+ *out_loop = &fill_zero_object_strided_loop;
+ return 0;
+}
+
/**************** Structured DType clear funcationality ***************/
/*
@@ -408,7 +441,6 @@ clear_no_op(
return 0;
}
-
NPY_NO_EXPORT int
npy_get_clear_void_and_legacy_user_dtype_loop(
void *traverse_context, PyArray_Descr *dtype, int aligned,
@@ -472,3 +504,41 @@ npy_get_clear_void_and_legacy_user_dtype_loop(
dtype);
return -1;
}
+
+/**************** Structured DType zero fill ***************/
+
+static int
+fill_zero_void_with_objects_strided_loop(
+ void *NPY_UNUSED(traverse_context), PyArray_Descr *descr,
+ char *data, npy_intp size, npy_intp stride,
+ NpyAuxData *NPY_UNUSED(auxdata))
+{
+ PyObject *zero = PyLong_FromLong(0);
+ while (size--) {
+ _fillobject(data, zero, descr);
+ data += stride;
+ }
+ Py_DECREF(zero);
+ return 0;
+}
+
+
+NPY_NO_EXPORT int
+npy_void_get_fill_zero_loop(void *NPY_UNUSED(traverse_context),
+ PyArray_Descr *descr,
+ int NPY_UNUSED(aligned),
+ npy_intp NPY_UNUSED(fixed_stride),
+ traverse_loop_function **out_loop,
+ NpyAuxData **NPY_UNUSED(out_auxdata),
+ NPY_ARRAYMETHOD_FLAGS *flags)
+{
+ *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
+ if (PyDataType_REFCHK(descr)) {
+ *flags |= NPY_METH_REQUIRES_PYAPI;
+ *out_loop = &fill_zero_void_with_objects_strided_loop;
+ }
+ else {
+ *out_loop = NULL;
+ }
+ return 0;
+}
diff --git a/numpy/core/src/multiarray/dtype_traversal.h b/numpy/core/src/multiarray/dtype_traversal.h
index fd060a0f0..a9c185382 100644
--- a/numpy/core/src/multiarray/dtype_traversal.h
+++ b/numpy/core/src/multiarray/dtype_traversal.h
@@ -19,6 +19,22 @@ npy_get_clear_void_and_legacy_user_dtype_loop(
traverse_loop_function **out_loop, NpyAuxData **out_traversedata,
NPY_ARRAYMETHOD_FLAGS *flags);
+/* NumPy DType zero-filling implementations */
+
+NPY_NO_EXPORT int
+npy_object_get_fill_zero_loop(
+ void *NPY_UNUSED(traverse_context), PyArray_Descr *NPY_UNUSED(descr),
+ int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride),
+ traverse_loop_function **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata),
+ NPY_ARRAYMETHOD_FLAGS *flags);
+
+NPY_NO_EXPORT int
+npy_void_get_fill_zero_loop(
+ void *NPY_UNUSED(traverse_context), PyArray_Descr *descr,
+ int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride),
+ traverse_loop_function **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata),
+ NPY_ARRAYMETHOD_FLAGS *flags);
+
/* Helper to deal with calling or nesting simple strided loops */
@@ -34,6 +50,7 @@ NPY_traverse_info_init(NPY_traverse_info *cast_info)
{
cast_info->func = NULL; /* mark as uninitialized. */
cast_info->auxdata = NULL; /* allow leaving auxdata untouched */
+ cast_info->descr = NULL; /* mark as uninitialized. */
}
@@ -45,7 +62,7 @@ NPY_traverse_info_xfree(NPY_traverse_info *traverse_info)
}
traverse_info->func = NULL;
NPY_AUXDATA_FREE(traverse_info->auxdata);
- Py_DECREF(traverse_info->descr);
+ Py_XDECREF(traverse_info->descr);
}
@@ -79,4 +96,4 @@ PyArray_GetClearFunction(
NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags);
-#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */ \ No newline at end of file
+#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPE_TRAVERSAL_H_ */
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index f268ba2cb..f8c1b6617 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -20,6 +20,8 @@
#include "usertypes.h"
#include "conversion_utils.h"
#include "templ_common.h"
+#include "refcount.h"
+#include "dtype_traversal.h"
#include <assert.h>
@@ -524,6 +526,7 @@ void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
return NULL;
}
+
NPY_NO_EXPORT int
python_builtins_are_known_scalar_types(
PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype)
@@ -855,6 +858,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
dt_slots->common_dtype = default_builtin_common_dtype;
dt_slots->common_instance = NULL;
dt_slots->ensure_canonical = ensure_native_byteorder;
+ dt_slots->get_fill_zero_loop = NULL;
if (PyTypeNum_ISSIGNED(dtype_class->type_num)) {
/* Convert our scalars (raise on too large unsigned and NaN, etc.) */
@@ -866,6 +870,8 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
}
else if (descr->type_num == NPY_OBJECT) {
dt_slots->common_dtype = object_common_dtype;
+ dt_slots->get_fill_zero_loop = npy_object_get_fill_zero_loop;
+ dt_slots->get_clear_loop = npy_get_clear_object_strided_loop;
}
else if (PyTypeNum_ISDATETIME(descr->type_num)) {
/* Datetimes are flexible, but were not considered previously */
@@ -887,6 +893,9 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
void_discover_descr_from_pyobject);
dt_slots->common_instance = void_common_instance;
dt_slots->ensure_canonical = void_ensure_canonical;
+ dt_slots->get_fill_zero_loop = npy_void_get_fill_zero_loop;
+ dt_slots->get_clear_loop =
+ npy_get_clear_void_and_legacy_user_dtype_loop;
}
else {
dt_slots->default_descr = string_and_unicode_default_descr;
diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h
index 3b4dbad24..6dbfd1549 100644
--- a/numpy/core/src/multiarray/dtypemeta.h
+++ b/numpy/core/src/multiarray/dtypemeta.h
@@ -42,6 +42,23 @@ typedef struct {
*/
get_traverse_loop_function *get_clear_loop;
/*
+ Either NULL or a function that sets a function pointer to a traversal
+ loop that fills an array with zero values appropriate for the dtype. If
+ get_fill_zero_loop is undefined or the function pointer set by it is
+ NULL, the array buffer is allocated with calloc. If this function is
+ defined and it sets a non-NULL function pointer, the array buffer is
+ allocated with malloc and the zero-filling loop function pointer is
+ called to fill the buffer. For the best performance, avoid using this
+ function if a zero-filled array buffer allocated with calloc makes sense
+ for the dtype.
+
+ Note that this is currently used only for zero-filling a newly allocated
+ array. While it can be used to zero-fill an already-filled buffer, that
+ will not work correctly for arrays holding references. If you need to do
+ that, clear the array first.
+ */
+ get_traverse_loop_function *get_fill_zero_loop;
+ /*
* The casting implementation (ArrayMethod) to convert between two
* instances of this DType, stored explicitly for fast access:
*/
@@ -63,7 +80,7 @@ typedef struct {
// This must be updated if new slots before within_dtype_castingimpl
// are added
-#define NPY_NUM_DTYPE_SLOTS 9
+#define NPY_NUM_DTYPE_SLOTS 10
#define NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS 22
#define NPY_DT_MAX_ARRFUNCS_SLOT \
NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS + _NPY_DT_ARRFUNCS_OFFSET
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index ab35548ed..d019acbb5 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -797,20 +797,17 @@ array_imag_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
}
else {
Py_INCREF(PyArray_DESCR(self));
- ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self),
- PyArray_DESCR(self),
- PyArray_NDIM(self),
- PyArray_DIMS(self),
- NULL, NULL,
- PyArray_ISFORTRAN(self),
- (PyObject *)self);
+ ret = (PyArrayObject *)PyArray_NewFromDescr_int(
+ Py_TYPE(self),
+ PyArray_DESCR(self),
+ PyArray_NDIM(self),
+ PyArray_DIMS(self),
+ NULL, NULL,
+ PyArray_ISFORTRAN(self),
+ (PyObject *)self, NULL, 1, 0);
if (ret == NULL) {
return NULL;
}
- if (_zerofill(ret) < 0) {
- Py_DECREF(ret);
- return NULL;
- }
PyArray_CLEARFLAGS(ret, NPY_ARRAY_WRITEABLE);
}
return (PyObject *) ret;
diff --git a/numpy/core/src/multiarray/refcount.c b/numpy/core/src/multiarray/refcount.c
index 20527f7af..d200957c3 100644
--- a/numpy/core/src/multiarray/refcount.c
+++ b/numpy/core/src/multiarray/refcount.c
@@ -17,15 +17,12 @@
#include "numpy/arrayscalars.h"
#include "iterators.h"
#include "dtypemeta.h"
+#include "refcount.h"
#include "npy_config.h"
#include "npy_pycompat.h"
-static void
-_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype);
-
-
/*
* Helper function to clear a strided memory (normally or always contiguous)
* from all Python (or other) references. The function does nothing if the
@@ -395,7 +392,7 @@ PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj)
}
}
-static void
+NPY_NO_EXPORT void
_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype)
{
if (!PyDataType_FLAGCHK(dtype, NPY_ITEM_REFCOUNT)) {
diff --git a/numpy/core/src/multiarray/refcount.h b/numpy/core/src/multiarray/refcount.h
index 16d34e292..7f39b9ca4 100644
--- a/numpy/core/src/multiarray/refcount.h
+++ b/numpy/core/src/multiarray/refcount.h
@@ -24,4 +24,7 @@ PyArray_XDECREF(PyArrayObject *mp);
NPY_NO_EXPORT void
PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj);
+NPY_NO_EXPORT void
+_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype);
+
#endif /* NUMPY_CORE_SRC_MULTIARRAY_REFCOUNT_H_ */