diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-10-18 14:14:54 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-10-18 14:14:54 -0600 |
commit | c3b48b91f151be1bd0f94cb0f05ad2e400dee9b4 (patch) | |
tree | 5ad0e911a555fad56a29f4cec845851c0d838f09 /numpy/core | |
parent | 3ddaa592d855597ae8afe13b3848ac032e693295 (diff) | |
parent | 8cf5b506d2d3da833b09e8bbbe874db6f9c5e809 (diff) | |
download | numpy-c3b48b91f151be1bd0f94cb0f05ad2e400dee9b4.tar.gz |
Merge pull request #6208 from ahaldane/fast_field_subscript
MAINT: Speedup field access by removing unneeded safety checks
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/_internal.py | 54 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 37 | ||||
-rw-r--r-- | numpy/core/src/multiarray/common.c | 51 | ||||
-rw-r--r-- | numpy/core/src/multiarray/common.h | 21 | ||||
-rw-r--r-- | numpy/core/src/multiarray/getset.c | 24 | ||||
-rw-r--r-- | numpy/core/src/multiarray/mapping.c | 250 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 117 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 4 |
9 files changed, 375 insertions, 208 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 81f5be4ad..3ddc2c64d 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -288,55 +288,23 @@ def _newnames(datatype, order): return tuple(list(order) + nameslist) raise ValueError("unsupported order value: %s" % (order,)) -def _index_fields(ary, names): - """ Given a structured array and a sequence of field names - construct new array with just those fields. +def _copy_fields(ary): + """Return copy of structured array with padding between fields removed. Parameters ---------- ary : ndarray - Structured array being subscripted - names : string or list of strings - Either a single field name, or a list of field names + Structured array from which to remove padding bytes Returns ------- - sub_ary : ndarray - If `names` is a single field name, the return value is identical to - ary.getfield, a writeable view into `ary`. If `names` is a list of - field names the return value is a copy of `ary` containing only those - fields. This is planned to return a view in the future. - - Raises - ------ - ValueError - If `ary` does not contain a field given in `names`. - + ary_copy : ndarray + Copy of ary with padding bytes removed """ dt = ary.dtype - - #use getfield to index a single field - if isinstance(names, basestring): - try: - return ary.getfield(dt.fields[names][0], dt.fields[names][1]) - except KeyError: - raise ValueError("no field of name %s" % names) - - for name in names: - if name not in dt.fields: - raise ValueError("no field of name %s" % name) - - formats = [dt.fields[name][0] for name in names] - offsets = [dt.fields[name][1] for name in names] - - view_dtype = {'names': names, 'formats': formats, - 'offsets': offsets, 'itemsize': dt.itemsize} - - # return copy for now (future plan to return ary.view(dtype=view_dtype)) - copy_dtype = {'names': view_dtype['names'], - 'formats': view_dtype['formats']} - return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True) - + copy_dtype = {'names': dt.names, + 'formats': [dt.fields[name][0] for name in dt.names]} + return array(ary, dtype=copy_dtype, copy=True) def _get_all_field_offsets(dtype, base_offset=0): """ Returns the types and offsets of all fields in a (possibly structured) @@ -478,6 +446,12 @@ def _view_is_safe(oldtype, newtype): If the new type is incompatible with the old type. """ + + # if the types are equivalent, there is no problem. + # for example: dtype((np.record, 'i4,i4')) == dtype((np.void, 'i4,i4')) + if oldtype == newtype: + return + new_fields = _get_all_field_offsets(newtype) new_size = newtype.itemsize diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 5aa7e6142..060f25098 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -621,31 +621,6 @@ OBJECT_setitem(PyObject *op, void *ov, void *NPY_UNUSED(ap)) /* VOID */ -/* unpack tuple of dtype->fields (descr, offset, title[not-needed]) */ -static int -unpack_field(PyObject * value, PyArray_Descr ** descr, npy_intp * offset) -{ - PyObject * off; - if (PyTuple_GET_SIZE(value) < 2) { - return -1; - } - *descr = (PyArray_Descr *)PyTuple_GET_ITEM(value, 0); - off = PyTuple_GET_ITEM(value, 1); - - if (PyInt_Check(off)) { - *offset = PyInt_AsSsize_t(off); - } - else if (PyLong_Check(off)) { - *offset = PyLong_AsSsize_t(off); - } - else { - return -1; - } - - return 0; -} - - static PyObject * VOID_getitem(void *input, void *vap) { @@ -674,7 +649,7 @@ VOID_getitem(void *input, void *vap) PyArray_Descr *new; key = PyTuple_GET_ITEM(names, i); tup = PyDict_GetItem(descr->fields, key); - if (unpack_field(tup, &new, &offset) < 0) { + if (_unpack_field(tup, &new, &offset) < 0) { Py_DECREF(ret); ((PyArrayObject_fields *)ap)->descr = descr; return NULL; @@ -811,7 +786,7 @@ VOID_setitem(PyObject *op, void *input, void *vap) npy_intp offset; key = PyTuple_GET_ITEM(names, i); tup = PyDict_GetItem(descr->fields, key); - if (unpack_field(tup, &new, &offset) < 0) { + if (_unpack_field(tup, &new, &offset) < 0) { ((PyArrayObject_fields *)ap)->descr = descr; return -1; } @@ -2178,7 +2153,7 @@ VOID_copyswapn (char *dst, npy_intp dstride, char *src, npy_intp sstride, if (NPY_TITLE_KEY(key, value)) { continue; } - if (unpack_field(value, &new, &offset) < 0) { + if (_unpack_field(value, &new, &offset) < 0) { ((PyArrayObject_fields *)arr)->descr = descr; return; } @@ -2247,7 +2222,7 @@ VOID_copyswap (char *dst, char *src, int swap, PyArrayObject *arr) if (NPY_TITLE_KEY(key, value)) { continue; } - if (unpack_field(value, &new, &offset) < 0) { + if (_unpack_field(value, &new, &offset) < 0) { ((PyArrayObject_fields *)arr)->descr = descr; return; } @@ -2560,7 +2535,7 @@ VOID_nonzero (char *ip, PyArrayObject *ap) if (NPY_TITLE_KEY(key, value)) { continue; } - if (unpack_field(value, &new, &offset) < 0) { + if (_unpack_field(value, &new, &offset) < 0) { PyErr_Clear(); continue; } @@ -2876,7 +2851,7 @@ VOID_compare(char *ip1, char *ip2, PyArrayObject *ap) npy_intp offset; key = PyTuple_GET_ITEM(names, i); tup = PyDict_GetItem(descr->fields, key); - if (unpack_field(tup, &new, &offset) < 0) { + if (_unpack_field(tup, &new, &offset) < 0) { goto finish; } /* descr is the only field checked by compare or copyswap */ diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 3352c3529..1948b8b61 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -876,3 +876,54 @@ end: Py_XDECREF(shape1_i); Py_XDECREF(shape2_j); } + +/** + * unpack tuple of dtype->fields (descr, offset, title[not-needed]) + * + * @param "value" should be the tuple. + * + * @return "descr" will be set to the field's dtype + * @return "offset" will be set to the field's offset + * + * returns -1 on failure, 0 on success. + */ +NPY_NO_EXPORT int +_unpack_field(PyObject *value, PyArray_Descr **descr, npy_intp *offset) +{ + PyObject * off; + if (PyTuple_GET_SIZE(value) < 2) { + return -1; + } + *descr = (PyArray_Descr *)PyTuple_GET_ITEM(value, 0); + off = PyTuple_GET_ITEM(value, 1); + + if (PyInt_Check(off)) { + *offset = PyInt_AsSsize_t(off); + } + else if (PyLong_Check(off)) { + *offset = PyLong_AsSsize_t(off); + } + else { + return -1; + } + + return 0; +} + +/* + * check whether arrays with datatype dtype might have object fields. This will + * only happen for structured dtypes (which may have hidden objects even if the + * HASOBJECT flag is false), object dtypes, or subarray dtypes whose base type + * is either of these. + */ +NPY_NO_EXPORT int +_may_have_objects(PyArray_Descr *dtype) +{ + PyArray_Descr *base = dtype; + if (PyDataType_HASSUBARRAY(dtype)) { + base = dtype->subarray->base; + } + + return (PyDataType_HASFIELDS(base) || + PyDataType_FLAGCHK(base, NPY_ITEM_HASOBJECT) ); +} diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h index 11993829f..5e14b80a7 100644 --- a/numpy/core/src/multiarray/common.h +++ b/numpy/core/src/multiarray/common.h @@ -75,6 +75,27 @@ convert_shape_to_string(npy_intp n, npy_intp *vals, char *ending); NPY_NO_EXPORT void dot_alignment_error(PyArrayObject *a, int i, PyArrayObject *b, int j); +/** + * unpack tuple of dtype->fields (descr, offset, title[not-needed]) + * + * @param "value" should be the tuple. + * + * @return "descr" will be set to the field's dtype + * @return "offset" will be set to the field's offset + * + * returns -1 on failure, 0 on success. + */ +NPY_NO_EXPORT int +_unpack_field(PyObject *value, PyArray_Descr **descr, npy_intp *offset); + +/* + * check whether arrays with datatype dtype might have object fields. This will + * only happen for structured dtypes (which may have hidden objects even if the + * HASOBJECT flag is false), object dtypes, or subarray dtypes whose base type + * is either of these. + */ +NPY_NO_EXPORT int +_may_have_objects(PyArray_Descr *dtype); /* * Returns -1 and sets an exception if *index is an invalid index for diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c index 5147b9735..549ea333a 100644 --- a/numpy/core/src/multiarray/getset.c +++ b/numpy/core/src/multiarray/getset.c @@ -438,10 +438,6 @@ array_descr_set(PyArrayObject *self, PyObject *arg) PyObject *safe; static PyObject *checkfunc = NULL; - npy_cache_import("numpy.core._internal", "_view_is_safe", &checkfunc); - if (checkfunc == NULL) { - return -1; - } if (arg == NULL) { PyErr_SetString(PyExc_AttributeError, @@ -456,13 +452,21 @@ array_descr_set(PyArrayObject *self, PyObject *arg) return -1; } - /* check that we are not reinterpreting memory containing Objects */ - safe = PyObject_CallFunction(checkfunc, "OO", PyArray_DESCR(self), newtype); - if (safe == NULL) { - Py_DECREF(newtype); - return -1; + /* check that we are not reinterpreting memory containing Objects. */ + if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(newtype)) { + npy_cache_import("numpy.core._internal", "_view_is_safe", &checkfunc); + if (checkfunc == NULL) { + return -1; + } + + safe = PyObject_CallFunction(checkfunc, "OO", + PyArray_DESCR(self), newtype); + if (safe == NULL) { + Py_DECREF(newtype); + return -1; + } + Py_DECREF(safe); } - Py_DECREF(safe); if (newtype->elsize == 0) { /* Allow a void view */ diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index 42a12db14..44de1cbf2 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -1250,51 +1250,190 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op) return PyArray_EnsureAnyArray(array_subscript(self, op)); } +/* + * Attempts to subscript an array using a field name or list of field names. + * + * If an error occurred, return 0 and set view to NULL. If the subscript is not + * a string or list of strings, return -1 and set view to NULL. Otherwise + * return 0 and set view to point to a new view into arr for the given fields. + */ NPY_NO_EXPORT int -obj_is_string_or_stringlist(PyObject *op) +_get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view) { + *view = NULL; + + /* first check for a single field name */ #if defined(NPY_PY3K) - if (PyUnicode_Check(op)) { + if (PyUnicode_Check(ind)) { #else - if (PyString_Check(op) || PyUnicode_Check(op)) { + if (PyString_Check(ind) || PyUnicode_Check(ind)) { #endif - return 1; + PyObject *tup; + PyArray_Descr *fieldtype; + npy_intp offset; + + /* get the field offset and dtype */ + tup = PyDict_GetItem(PyArray_DESCR(arr)->fields, ind); + if (tup == NULL){ + PyObject *errmsg = PyUString_FromString("no field of name "); + PyUString_Concat(&errmsg, ind); + PyErr_SetObject(PyExc_ValueError, errmsg); + Py_DECREF(errmsg); + return 0; + } + if (_unpack_field(tup, &fieldtype, &offset) < 0) { + return 0; + } + + /* view the array at the new offset+dtype */ + Py_INCREF(fieldtype); + *view = (PyArrayObject*)PyArray_NewFromDescr( + Py_TYPE(arr), + fieldtype, + PyArray_NDIM(arr), + PyArray_SHAPE(arr), + PyArray_STRIDES(arr), + PyArray_DATA(arr) + offset, + PyArray_FLAGS(arr), + (PyObject *)arr); + if (*view == NULL) { + return 0; + } + Py_INCREF(arr); + if (PyArray_SetBaseObject(*view, (PyObject *)arr) < 0) { + Py_DECREF(*view); + *view = NULL; + } + return 0; } - else if (PySequence_Check(op) && !PyTuple_Check(op)) { + /* next check for a list of field names */ + else if (PySequence_Check(ind) && !PyTuple_Check(ind)) { int seqlen, i; - PyObject *obj = NULL; - seqlen = PySequence_Size(op); + PyObject *name = NULL, *tup; + PyObject *fields, *names; + PyArray_Descr *view_dtype; + + /* variables needed to make a copy, to remove in the future */ + static PyObject *copyfunc = NULL; + PyObject *viewcopy; + + seqlen = PySequence_Size(ind); - /* quit if we come across a 0-d array (seqlen==-1) or a 0-len array */ + /* quit if have a 0-d array (seqlen==-1) or a 0-len array */ if (seqlen == -1) { PyErr_Clear(); - return 0; + return -1; } if (seqlen == 0) { + return -1; + } + + fields = PyDict_New(); + if (fields == NULL) { + return 0; + } + names = PyTuple_New(seqlen); + if (names == NULL) { + Py_DECREF(fields); return 0; } for (i = 0; i < seqlen; i++) { - obj = PySequence_GetItem(op, i); - if (obj == NULL) { - /* only happens for strange sequence objects. Silently fail */ + name = PySequence_GetItem(ind, i); + if (name == NULL) { + /* only happens for strange sequence objects */ PyErr_Clear(); - return 0; + Py_DECREF(fields); + Py_DECREF(names); + return -1; } #if defined(NPY_PY3K) - if (!PyUnicode_Check(obj)) { + if (!PyUnicode_Check(name)) { #else - if (!PyString_Check(obj) && !PyUnicode_Check(obj)) { + if (!PyString_Check(name) && !PyUnicode_Check(name)) { #endif - Py_DECREF(obj); + Py_DECREF(name); + Py_DECREF(fields); + Py_DECREF(names); + return -1; + } + + tup = PyDict_GetItem(PyArray_DESCR(arr)->fields, name); + if (tup == NULL){ + PyObject *errmsg = PyUString_FromString("no field of name "); + PyUString_ConcatAndDel(&errmsg, name); + PyErr_SetObject(PyExc_ValueError, errmsg); + Py_DECREF(errmsg); + Py_DECREF(fields); + Py_DECREF(names); return 0; } - Py_DECREF(obj); + if (PyDict_SetItem(fields, name, tup) < 0) { + Py_DECREF(name); + Py_DECREF(fields); + Py_DECREF(names); + return 0; + } + if (PyTuple_SetItem(names, i, name) < 0) { + Py_DECREF(fields); + Py_DECREF(names); + return 0; + } + } + + view_dtype = PyArray_DescrNewFromType(NPY_VOID); + if (view_dtype == NULL) { + Py_DECREF(fields); + Py_DECREF(names); + return 0; + } + view_dtype->elsize = PyArray_DESCR(arr)->elsize; + view_dtype->names = names; + view_dtype->fields = fields; + view_dtype->flags = PyArray_DESCR(arr)->flags; + + *view = (PyArrayObject*)PyArray_NewFromDescr( + Py_TYPE(arr), + view_dtype, + PyArray_NDIM(arr), + PyArray_SHAPE(arr), + PyArray_STRIDES(arr), + PyArray_DATA(arr), + PyArray_FLAGS(arr), + (PyObject *)arr); + if (*view == NULL) { + return 0; + } + Py_INCREF(arr); + if (PyArray_SetBaseObject(*view, (PyObject *)arr) < 0) { + Py_DECREF(*view); + *view = NULL; + return 0; + } + + /* + * Return copy for now (future plan to return the view above). All the + * following code in this block can then be replaced by "return 0;" + */ + npy_cache_import("numpy.core._internal", "_copy_fields", ©func); + if (copyfunc == NULL) { + Py_DECREF(*view); + *view = NULL; + return 0; + } + + viewcopy = PyObject_CallFunction(copyfunc, "O", *view); + if (viewcopy == NULL) { + Py_DECREF(*view); + *view = NULL; + return 0; } - return 1; + Py_DECREF(*view); + *view = (PyArrayObject*)viewcopy; + return 0; } - return 0; + return -1; } /* @@ -1318,25 +1457,20 @@ array_subscript(PyArrayObject *self, PyObject *op) PyArrayMapIterObject * mit = NULL; /* return fields if op is a string index */ - if (PyDataType_HASFIELDS(PyArray_DESCR(self)) && - obj_is_string_or_stringlist(op)) { - PyObject *obj; - static PyObject *indexfunc = NULL; - npy_cache_import("numpy.core._internal", "_index_fields", &indexfunc); - if (indexfunc == NULL) { - return NULL; - } - - obj = PyObject_CallFunction(indexfunc, "OO", self, op); - if (obj == NULL) { - return NULL; - } + if (PyDataType_HASFIELDS(PyArray_DESCR(self))) { + PyArrayObject *view; + int ret = _get_field_view(self, op, &view); + if (ret == 0){ + if (view == NULL) { + return NULL; + } - /* warn if writing to a copy. copies will have no base */ - if (PyArray_BASE((PyArrayObject*)obj) == NULL) { - PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE); + /* warn if writing to a copy. copies will have no base */ + if (PyArray_BASE(view) == NULL) { + PyArray_ENABLEFLAGS(view, NPY_ARRAY_WARN_ON_WRITE); + } + return (PyObject*)view; } - return obj; } /* Prepare the indices */ @@ -1671,37 +1805,31 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op) } /* field access */ - if (PyDataType_HASFIELDS(PyArray_DESCR(self)) && - obj_is_string_or_stringlist(ind)) { - PyObject *obj; - static PyObject *indexfunc = NULL; + if (PyDataType_HASFIELDS(PyArray_DESCR(self))){ + PyArrayObject *view; + int ret = _get_field_view(self, ind, &view); + if (ret == 0){ #if defined(NPY_PY3K) - if (!PyUnicode_Check(ind)) { + if (!PyUnicode_Check(ind)) { #else - if (!PyString_Check(ind) && !PyUnicode_Check(ind)) { + if (!PyString_Check(ind) && !PyUnicode_Check(ind)) { #endif - PyErr_SetString(PyExc_ValueError, - "multi-field assignment is not supported"); - } - - npy_cache_import("numpy.core._internal", "_index_fields", &indexfunc); - if (indexfunc == NULL) { - return -1; - } - - obj = PyObject_CallFunction(indexfunc, "OO", self, ind); - if (obj == NULL) { - return -1; - } + PyErr_SetString(PyExc_ValueError, + "multi-field assignment is not supported"); + return -1; + } - if (PyArray_CopyObject((PyArrayObject*)obj, op) < 0) { - Py_DECREF(obj); - return -1; + if (view == NULL) { + return -1; + } + if (PyArray_CopyObject(view, op) < 0) { + Py_DECREF(view); + return -1; + } + Py_DECREF(view); + return 0; } - Py_DECREF(obj); - - return 0; } /* Prepare the indices */ diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index fd329cb8c..84d4e2c9e 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -362,19 +362,22 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset) PyObject *safe; static PyObject *checkfunc = NULL; - npy_cache_import("numpy.core._internal", "_getfield_is_safe", &checkfunc); - if (checkfunc == NULL) { - return NULL; - } + /* check that we are not reinterpreting memory containing Objects. */ + if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(typed)) { + npy_cache_import("numpy.core._internal", "_getfield_is_safe", + &checkfunc); + if (checkfunc == NULL) { + return NULL; + } - /* check that we are not reinterpreting memory containing Objects */ - /* only returns True or raises */ - safe = PyObject_CallFunction(checkfunc, "OOi", PyArray_DESCR(self), - typed, offset); - if (safe == NULL) { - return NULL; + /* only returns True or raises */ + safe = PyObject_CallFunction(checkfunc, "OOi", PyArray_DESCR(self), + typed, offset); + if (safe == NULL) { + return NULL; + } + Py_DECREF(safe); } - Py_DECREF(safe); ret = PyArray_NewFromDescr(Py_TYPE(self), typed, diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index ee5741ae0..1bd5b22d2 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1681,13 +1681,15 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) * b['x'][0] = arange(3) # uses ndarray setitem * * Ndarray's setfield would try to broadcast the lhs. Instead we use - * ndarray getfield to get the field safely, then setitem to set the value - * without broadcast. Note we also want subarrays to be set properly, ie + * ndarray getfield to get the field safely, then setitem with an empty + * tuple to set the value without broadcast. Note we also want subarrays to + * be set properly, ie * * a = np.zeros(1, dtype=[('x', 'i', 5)]) * a[0]['x'] = 1 * - * sets all values to 1. Setitem does this. + * sets all values to 1. "getfield + setitem with empty tuple" takes + * care of both object arrays and subarrays. */ PyObject *getfield_args, *value, *arr, *meth, *arr_field, *emptytuple; @@ -1726,15 +1728,15 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) return NULL; } - /* 2. Fill the resulting array using setitem */ + /* 2. Assign the value using setitem with empty tuple. */ emptytuple = PyTuple_New(0); if (PyObject_SetItem(arr_field, emptytuple, value) < 0) { Py_DECREF(arr_field); Py_DECREF(emptytuple); return NULL; } - Py_DECREF(arr_field); Py_DECREF(emptytuple); + Py_DECREF(arr_field); Py_RETURN_NONE; } @@ -2158,10 +2160,13 @@ voidtype_length(PyVoidScalarObject *self) } static PyObject * +voidtype_subscript(PyVoidScalarObject *self, PyObject *ind); + +static PyObject * voidtype_item(PyVoidScalarObject *self, Py_ssize_t n) { npy_intp m; - PyObject *flist=NULL, *fieldind, *fieldparam, *fieldinfo, *ret; + PyObject *flist=NULL; if (!(PyDataType_HASFIELDS(self->descr))) { PyErr_SetString(PyExc_IndexError, @@ -2177,22 +2182,16 @@ voidtype_item(PyVoidScalarObject *self, Py_ssize_t n) PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n); return NULL; } - /* no error checking needed: descr->names is well structured */ - fieldind = PyTuple_GET_ITEM(flist, n); - fieldparam = PyDict_GetItem(self->descr->fields, fieldind); - fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2); - ret = voidtype_getfield(self, fieldinfo, NULL); - Py_DECREF(fieldinfo); - return ret; -} + return voidtype_subscript(self, PyTuple_GetItem(flist, n)); +} /* get field by name or number */ static PyObject * voidtype_subscript(PyVoidScalarObject *self, PyObject *ind) { npy_intp n; - PyObject *ret, *fieldinfo, *fieldparam; + PyObject *ret, *args; if (!(PyDataType_HASFIELDS(self->descr))) { PyErr_SetString(PyExc_IndexError, @@ -2205,14 +2204,9 @@ voidtype_subscript(PyVoidScalarObject *self, PyObject *ind) #else if (PyBytes_Check(ind) || PyUnicode_Check(ind)) { #endif - /* look up in fields */ - fieldparam = PyDict_GetItem(self->descr->fields, ind); - if (!fieldparam) { - goto fail; - } - fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2); - ret = voidtype_getfield(self, fieldinfo, NULL); - Py_DECREF(fieldinfo); + args = Py_BuildValue("(O)", ind); + ret = gentype_generic_method((PyObject *)self, args, NULL, "__getitem__"); + Py_DECREF(args); return ret; } @@ -2229,11 +2223,13 @@ fail: } static int +voidtype_ass_subscript(PyVoidScalarObject *self, PyObject *ind, PyObject *val); + +static int voidtype_ass_item(PyVoidScalarObject *self, Py_ssize_t n, PyObject *val) { npy_intp m; - PyObject *flist=NULL, *fieldinfo, *newtup; - PyObject *res; + PyObject *flist=NULL; if (!(PyDataType_HASFIELDS(self->descr))) { PyErr_SetString(PyExc_IndexError, @@ -2247,24 +2243,11 @@ voidtype_ass_item(PyVoidScalarObject *self, Py_ssize_t n, PyObject *val) n += m; } if (n < 0 || n >= m) { - goto fail; - } - fieldinfo = PyDict_GetItem(self->descr->fields, - PyTuple_GET_ITEM(flist, n)); - newtup = Py_BuildValue("(OOO)", val, - PyTuple_GET_ITEM(fieldinfo, 0), - PyTuple_GET_ITEM(fieldinfo, 1)); - res = voidtype_setfield(self, newtup, NULL); - Py_DECREF(newtup); - if (!res) { + PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n); return -1; } - Py_DECREF(res); - return 0; -fail: - PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n); - return -1; + return voidtype_ass_subscript(self, PyTuple_GetItem(flist, n), val); } static int @@ -2272,8 +2255,7 @@ voidtype_ass_subscript(PyVoidScalarObject *self, PyObject *ind, PyObject *val) { npy_intp n; char *msg = "invalid index"; - PyObject *fieldinfo, *newtup; - PyObject *res; + PyObject *args; if (!PyDataType_HASFIELDS(self->descr)) { PyErr_SetString(PyExc_IndexError, @@ -2292,20 +2274,49 @@ voidtype_ass_subscript(PyVoidScalarObject *self, PyObject *ind, PyObject *val) #else if (PyBytes_Check(ind) || PyUnicode_Check(ind)) { #endif - /* look up in fields */ - fieldinfo = PyDict_GetItem(self->descr->fields, ind); - if (!fieldinfo) { - goto fail; + /* + * Much like in voidtype_setfield, we cannot simply use ndarray's + * __setitem__ since assignment to void scalars should not broadcast + * the lhs. Instead we get a view through __getitem__ and then assign + * the value using setitem with an empty tuple (which treats both + * object arrays and subarrays properly). + * + * Also we do not want to use voidtype_setfield here, since we do + * not need to do the (slow) view safety checks, since we already + * know the dtype/offset are safe. + */ + + PyObject *arr, *arr_field, *meth, *emptytuple; + + /* 1. Convert to 0-d array and use getitem */ + arr = PyArray_FromScalar((PyObject*)self, NULL); + if (arr == NULL) { + return -1; + } + meth = PyObject_GetAttrString(arr, "__getitem__"); + if (meth == NULL) { + Py_DECREF(arr); + return -1; } - newtup = Py_BuildValue("(OOO)", val, - PyTuple_GET_ITEM(fieldinfo, 0), - PyTuple_GET_ITEM(fieldinfo, 1)); - res = voidtype_setfield(self, newtup, NULL); - Py_DECREF(newtup); - if (!res) { + args = Py_BuildValue("(O)", ind); + arr_field = PyObject_CallObject(meth, args); + Py_DECREF(meth); + Py_DECREF(arr); + Py_DECREF(args); + + if(arr_field == NULL){ return -1; } - Py_DECREF(res); + + /* 2. Assign the value using setitem with empty tuple. */ + emptytuple = PyTuple_New(0); + if (PyObject_SetItem(arr_field, emptytuple, val) < 0) { + Py_DECREF(arr_field); + Py_DECREF(emptytuple); + return -1; + } + Py_DECREF(emptytuple); + Py_DECREF(arr_field); return 0; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index d47b9f0da..85b0e5519 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3754,8 +3754,8 @@ class TestRecord(TestCase): b[0][fn1] = 2 assert_equal(b[fn1], 2) # Subfield - assert_raises(IndexError, b[0].__setitem__, fnn, 1) - assert_raises(IndexError, b[0].__getitem__, fnn) + assert_raises(ValueError, b[0].__setitem__, fnn, 1) + assert_raises(ValueError, b[0].__getitem__, fnn) # Subfield fn3 = func('f3') sfn1 = func('sf1') |