diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 97 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 12 |
2 files changed, 101 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 3d93e801a..238077b36 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -2498,7 +2498,19 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) if ((fields == Py_None && names != Py_None) || (names == Py_None && fields != Py_None)) { PyErr_Format(PyExc_ValueError, - "inconsistent fields and names"); + "inconsistent fields and names in Numpy dtype unpickling"); + return NULL; + } + + if (names != Py_None && !PyTuple_Check(names)) { + PyErr_Format(PyExc_ValueError, + "non-tuple names in Numpy dtype unpickling"); + return NULL; + } + + if (fields != Py_None && !PyDict_Check(fields)) { + PyErr_Format(PyExc_ValueError, + "non-dict fields in Numpy dtype unpickling"); return NULL; } @@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) } if (fields != Py_None) { - Py_XDECREF(self->fields); - self->fields = fields; - Py_INCREF(fields); - Py_XDECREF(self->names); - self->names = names; - if (incref_names) { - Py_INCREF(names); + /* + * Ensure names are of appropriate string type + */ + Py_ssize_t i; + int names_ok = 1; + PyObject *name; + + for (i = 0; i < PyTuple_GET_SIZE(names); ++i) { + name = PyTuple_GET_ITEM(names, i); + if (!PyUString_Check(name)) { + names_ok = 0; + break; + } + } + + if (names_ok) { + Py_XDECREF(self->fields); + self->fields = fields; + Py_INCREF(fields); + Py_XDECREF(self->names); + self->names = names; + if (incref_names) { + Py_INCREF(names); + } + } + else { +#if defined(NPY_PY3K) + /* + * To support pickle.load(f, encoding='bytes') for loading Py2 + * generated pickles on Py3, we need to be more lenient and convert + * field names from byte strings to unicode. + */ + PyObject *tmp, *new_name, *field; + + tmp = PyDict_New(); + if (tmp == NULL) { + return NULL; + } + Py_XDECREF(self->fields); + self->fields = tmp; + + tmp = PyTuple_New(PyTuple_GET_SIZE(names)); + if (tmp == NULL) { + return NULL; + } + Py_XDECREF(self->names); + self->names = tmp; + + for (i = 0; i < PyTuple_GET_SIZE(names); ++i) { + name = PyTuple_GET_ITEM(names, i); + field = PyDict_GetItem(fields, name); + if (!field) { + return NULL; + } + + if (PyUnicode_Check(name)) { + new_name = name; + Py_INCREF(new_name); + } + else { + new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict"); + if (new_name == NULL) { + return NULL; + } + } + + PyTuple_SET_ITEM(self->names, i, new_name); + if (PyDict_SetItem(self->fields, new_name, field) != 0) { + return NULL; + } + } +#else + PyErr_Format(PyExc_ValueError, + "non-string names in Numpy dtype unpickling"); + return NULL; +#endif } } diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index a83713a75..3bf1a28f9 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -414,6 +414,14 @@ class TestRegression(TestCase): "p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n" "p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n" "I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")), + + (np.array([(9e123,)], dtype=[('name', float)]), + asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n" + "(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n" + "(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n" + "(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n" + "I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n" + "bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")), ] if sys.version_info[:2] >= (3, 4): @@ -422,6 +430,10 @@ class TestRegression(TestCase): result = pickle.loads(data, encoding='bytes') assert_equal(result, original) + if isinstance(result, np.ndarray) and result.dtype.names: + for name in result.dtype.names: + assert_(isinstance(name, str)) + def test_pickle_dtype(self,level=rlevel): """Ticket #251""" pickle.dumps(np.float) |