diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/_internal.py | 19 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 30 | ||||
-rw-r--r-- | numpy/core/tests/test_numerictypes.py | 10 |
3 files changed, 58 insertions, 1 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 558d2fe93..7d5c3a49e 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -292,3 +292,22 @@ def _newnames(datatype, order): raise ValueError, "unknown field name: %s" % (name,) return tuple(list(order) + nameslist) raise ValueError, "unsupported order value: %s" % (order,) + +# Given an array with fields and a sequence of field names +# construct a new array with just those fields copied over +def _index_fields(ary, fields): + from multiarray import empty, dtype + dt = ary.dtype + new_dtype = [(name, dt[name]) for name in dt.names if name in fields] + if ary.flags.f_contiguous: + order = 'F' + else: + order = 'C' + + newarray = empty(ary.shape, dtype=new_dtype, order=order) + + for name in fields: + newarray[name] = ary[name] + + return newarray + diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 32d49eaf2..6c8a64d84 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2827,10 +2827,10 @@ array_subscript(PyArrayObject *self, PyObject *op) int nd, fancy; PyArrayObject *other; PyArrayMapIterObject *mit; + PyObject *obj; if (PyString_Check(op) || PyUnicode_Check(op)) { if (self->descr->names) { - PyObject *obj; obj = PyDict_GetItem(self->descr->fields, op); if (obj != NULL) { PyArray_Descr *descr; @@ -2852,6 +2852,34 @@ array_subscript(PyArrayObject *self, PyObject *op) return NULL; } + /* Check for multiple field access + */ + if (self->descr->names && PySequence_Check(op) && !PyTuple_Check(op)) { + int seqlen, i; + seqlen = PySequence_Size(op); + for (i=0; i<seqlen; i++) { + obj = PySequence_GetItem(op, i); + if (!PyString_Check(obj) && !PyUnicode_Check(obj)) { + Py_DECREF(obj); + break; + } + Py_DECREF(obj); + } + /* extract multiple fields if all elements in sequence + are either string or unicode (i.e. no break occurred). + */ + fancy = ((seqlen > 0) && (i == seqlen)); + if (fancy) { + PyObject *_numpy_internal; + _numpy_internal = PyImport_ImportModule("numpy.core._internal"); + if (_numpy_internal == NULL) return NULL; + obj = PyObject_CallMethod(_numpy_internal, "_index_fields", + "OO", self, op); + Py_DECREF(_numpy_internal); + return obj; + } + } + if (op == Py_Ellipsis) { Py_INCREF(self); return (PyObject *)self; diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py index 7f857065c..4e0bb462b 100644 --- a/numpy/core/tests/test_numerictypes.py +++ b/numpy/core/tests/test_numerictypes.py @@ -353,6 +353,16 @@ class TestCommonType(TestCase): res = np.find_common_type(['u8','i8','i8'],['f8']) assert(res == 'f8') +class TestMultipleFields(TestCase): + def setUp(self): + self.ary = np.array([(1,2,3,4),(5,6,7,8)], dtype='i4,f4,i2,c8') + def _bad_call(self): + return self.ary['f0','f1'] + def test_no_tuple(self): + self.failUnlessRaises(ValueError, self._bad_call) + def test_return(self): + res = self.ary[['f0','f2']].tolist() + assert(res == [(1,3), (5,7)]) if __name__ == "__main__": run_module_suite() |