summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/_internal.py19
-rw-r--r--numpy/core/src/arrayobject.c30
-rw-r--r--numpy/core/tests/test_numerictypes.py10
3 files changed, 58 insertions, 1 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 558d2fe93..7d5c3a49e 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -292,3 +292,22 @@ def _newnames(datatype, order):
raise ValueError, "unknown field name: %s" % (name,)
return tuple(list(order) + nameslist)
raise ValueError, "unsupported order value: %s" % (order,)
+
+# Given an array with fields and a sequence of field names
+# construct a new array with just those fields copied over
+def _index_fields(ary, fields):
+ from multiarray import empty, dtype
+ dt = ary.dtype
+ new_dtype = [(name, dt[name]) for name in dt.names if name in fields]
+ if ary.flags.f_contiguous:
+ order = 'F'
+ else:
+ order = 'C'
+
+ newarray = empty(ary.shape, dtype=new_dtype, order=order)
+
+ for name in fields:
+ newarray[name] = ary[name]
+
+ return newarray
+
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 32d49eaf2..6c8a64d84 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2827,10 +2827,10 @@ array_subscript(PyArrayObject *self, PyObject *op)
int nd, fancy;
PyArrayObject *other;
PyArrayMapIterObject *mit;
+ PyObject *obj;
if (PyString_Check(op) || PyUnicode_Check(op)) {
if (self->descr->names) {
- PyObject *obj;
obj = PyDict_GetItem(self->descr->fields, op);
if (obj != NULL) {
PyArray_Descr *descr;
@@ -2852,6 +2852,34 @@ array_subscript(PyArrayObject *self, PyObject *op)
return NULL;
}
+ /* Check for multiple field access
+ */
+ if (self->descr->names && PySequence_Check(op) && !PyTuple_Check(op)) {
+ int seqlen, i;
+ seqlen = PySequence_Size(op);
+ for (i=0; i<seqlen; i++) {
+ obj = PySequence_GetItem(op, i);
+ if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
+ Py_DECREF(obj);
+ break;
+ }
+ Py_DECREF(obj);
+ }
+ /* extract multiple fields if all elements in sequence
+ are either string or unicode (i.e. no break occurred).
+ */
+ fancy = ((seqlen > 0) && (i == seqlen));
+ if (fancy) {
+ PyObject *_numpy_internal;
+ _numpy_internal = PyImport_ImportModule("numpy.core._internal");
+ if (_numpy_internal == NULL) return NULL;
+ obj = PyObject_CallMethod(_numpy_internal, "_index_fields",
+ "OO", self, op);
+ Py_DECREF(_numpy_internal);
+ return obj;
+ }
+ }
+
if (op == Py_Ellipsis) {
Py_INCREF(self);
return (PyObject *)self;
diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py
index 7f857065c..4e0bb462b 100644
--- a/numpy/core/tests/test_numerictypes.py
+++ b/numpy/core/tests/test_numerictypes.py
@@ -353,6 +353,16 @@ class TestCommonType(TestCase):
res = np.find_common_type(['u8','i8','i8'],['f8'])
assert(res == 'f8')
+class TestMultipleFields(TestCase):
+ def setUp(self):
+ self.ary = np.array([(1,2,3,4),(5,6,7,8)], dtype='i4,f4,i2,c8')
+ def _bad_call(self):
+ return self.ary['f0','f1']
+ def test_no_tuple(self):
+ self.failUnlessRaises(ValueError, self._bad_call)
+ def test_return(self):
+ res = self.ary[['f0','f2']].tolist()
+ assert(res == [(1,3), (5,7)])
if __name__ == "__main__":
run_module_suite()