diff options
author | Pauli Virtanen <pav@iki.fi> | 2014-07-19 00:23:58 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2014-07-22 23:42:03 +0300 |
commit | 16f39c80f69ae4695f2f940ccdff8b26db3a2a01 (patch) | |
tree | dab647108acd39cb4d1bfd7ec21db70451993f62 | |
parent | 4008bb41414184d875af4f733a3ff3d303b94a4a (diff) | |
download | numpy-16f39c80f69ae4695f2f940ccdff8b26db3a2a01.tar.gz |
BUG: core: ensure unpickled dtype fields and names have correct types + coerce on Py3
That 'names' is a tuple and 'fields' a dict (when present) is assumed in
most of the code base, so check them during unpickling.
Also add coercion from bytes using ASCII codec on Python 3. This is
never triggered in the "usual" case of loading Py3-generated pickles on
Py3, but can occur if loading Py2 generated pickles with
pickle.load(f, encoding='bytes'), which sometimes is the only working
option.
The ASCII codec is probably the safest choice and likely covers most use
cases.
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 97 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 12 |
2 files changed, 101 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 3d93e801a..238077b36 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -2498,7 +2498,19 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) if ((fields == Py_None && names != Py_None) || (names == Py_None && fields != Py_None)) { PyErr_Format(PyExc_ValueError, - "inconsistent fields and names"); + "inconsistent fields and names in Numpy dtype unpickling"); + return NULL; + } + + if (names != Py_None && !PyTuple_Check(names)) { + PyErr_Format(PyExc_ValueError, + "non-tuple names in Numpy dtype unpickling"); + return NULL; + } + + if (fields != Py_None && !PyDict_Check(fields)) { + PyErr_Format(PyExc_ValueError, + "non-dict fields in Numpy dtype unpickling"); return NULL; } @@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) } if (fields != Py_None) { - Py_XDECREF(self->fields); - self->fields = fields; - Py_INCREF(fields); - Py_XDECREF(self->names); - self->names = names; - if (incref_names) { - Py_INCREF(names); + /* + * Ensure names are of appropriate string type + */ + Py_ssize_t i; + int names_ok = 1; + PyObject *name; + + for (i = 0; i < PyTuple_GET_SIZE(names); ++i) { + name = PyTuple_GET_ITEM(names, i); + if (!PyUString_Check(name)) { + names_ok = 0; + break; + } + } + + if (names_ok) { + Py_XDECREF(self->fields); + self->fields = fields; + Py_INCREF(fields); + Py_XDECREF(self->names); + self->names = names; + if (incref_names) { + Py_INCREF(names); + } + } + else { +#if defined(NPY_PY3K) + /* + * To support pickle.load(f, encoding='bytes') for loading Py2 + * generated pickles on Py3, we need to be more lenient and convert + * field names from byte strings to unicode. + */ + PyObject *tmp, *new_name, *field; + + tmp = PyDict_New(); + if (tmp == NULL) { + return NULL; + } + Py_XDECREF(self->fields); + self->fields = tmp; + + tmp = PyTuple_New(PyTuple_GET_SIZE(names)); + if (tmp == NULL) { + return NULL; + } + Py_XDECREF(self->names); + self->names = tmp; + + for (i = 0; i < PyTuple_GET_SIZE(names); ++i) { + name = PyTuple_GET_ITEM(names, i); + field = PyDict_GetItem(fields, name); + if (!field) { + return NULL; + } + + if (PyUnicode_Check(name)) { + new_name = name; + Py_INCREF(new_name); + } + else { + new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict"); + if (new_name == NULL) { + return NULL; + } + } + + PyTuple_SET_ITEM(self->names, i, new_name); + if (PyDict_SetItem(self->fields, new_name, field) != 0) { + return NULL; + } + } +#else + PyErr_Format(PyExc_ValueError, + "non-string names in Numpy dtype unpickling"); + return NULL; +#endif } } diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index a83713a75..3bf1a28f9 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -414,6 +414,14 @@ class TestRegression(TestCase): "p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n" "p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n" "I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")), + + (np.array([(9e123,)], dtype=[('name', float)]), + asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n" + "(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n" + "(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n" + "(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n" + "I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n" + "bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")), ] if sys.version_info[:2] >= (3, 4): @@ -422,6 +430,10 @@ class TestRegression(TestCase): result = pickle.loads(data, encoding='bytes') assert_equal(result, original) + if isinstance(result, np.ndarray) and result.dtype.names: + for name in result.dtype.names: + assert_(isinstance(name, str)) + def test_pickle_dtype(self,level=rlevel): """Ticket #251""" pickle.dumps(np.float) |