summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/descriptor.c97
-rw-r--r--numpy/core/tests/test_regression.py12
2 files changed, 101 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 3d93e801a..238077b36 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -2498,7 +2498,19 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
if ((fields == Py_None && names != Py_None) ||
(names == Py_None && fields != Py_None)) {
PyErr_Format(PyExc_ValueError,
- "inconsistent fields and names");
+ "inconsistent fields and names in Numpy dtype unpickling");
+ return NULL;
+ }
+
+ if (names != Py_None && !PyTuple_Check(names)) {
+ PyErr_Format(PyExc_ValueError,
+ "non-tuple names in Numpy dtype unpickling");
+ return NULL;
+ }
+
+ if (fields != Py_None && !PyDict_Check(fields)) {
+ PyErr_Format(PyExc_ValueError,
+ "non-dict fields in Numpy dtype unpickling");
return NULL;
}
@@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
if (fields != Py_None) {
- Py_XDECREF(self->fields);
- self->fields = fields;
- Py_INCREF(fields);
- Py_XDECREF(self->names);
- self->names = names;
- if (incref_names) {
- Py_INCREF(names);
+ /*
+ * Ensure names are of appropriate string type
+ */
+ Py_ssize_t i;
+ int names_ok = 1;
+ PyObject *name;
+
+ for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
+ name = PyTuple_GET_ITEM(names, i);
+ if (!PyUString_Check(name)) {
+ names_ok = 0;
+ break;
+ }
+ }
+
+ if (names_ok) {
+ Py_XDECREF(self->fields);
+ self->fields = fields;
+ Py_INCREF(fields);
+ Py_XDECREF(self->names);
+ self->names = names;
+ if (incref_names) {
+ Py_INCREF(names);
+ }
+ }
+ else {
+#if defined(NPY_PY3K)
+ /*
+ * To support pickle.load(f, encoding='bytes') for loading Py2
+ * generated pickles on Py3, we need to be more lenient and convert
+ * field names from byte strings to unicode.
+ */
+ PyObject *tmp, *new_name, *field;
+
+ tmp = PyDict_New();
+ if (tmp == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(self->fields);
+ self->fields = tmp;
+
+ tmp = PyTuple_New(PyTuple_GET_SIZE(names));
+ if (tmp == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(self->names);
+ self->names = tmp;
+
+ for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
+ name = PyTuple_GET_ITEM(names, i);
+ field = PyDict_GetItem(fields, name);
+ if (!field) {
+ return NULL;
+ }
+
+ if (PyUnicode_Check(name)) {
+ new_name = name;
+ Py_INCREF(new_name);
+ }
+ else {
+ new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict");
+ if (new_name == NULL) {
+ return NULL;
+ }
+ }
+
+ PyTuple_SET_ITEM(self->names, i, new_name);
+ if (PyDict_SetItem(self->fields, new_name, field) != 0) {
+ return NULL;
+ }
+ }
+#else
+ PyErr_Format(PyExc_ValueError,
+ "non-string names in Numpy dtype unpickling");
+ return NULL;
+#endif
}
}
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index a83713a75..3bf1a28f9 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -414,6 +414,14 @@ class TestRegression(TestCase):
"p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n"
"p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n"
"I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")),
+
+ (np.array([(9e123,)], dtype=[('name', float)]),
+ asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n"
+ "(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n"
+ "(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n"
+ "(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n"
+ "I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n"
+ "bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")),
]
if sys.version_info[:2] >= (3, 4):
@@ -422,6 +430,10 @@ class TestRegression(TestCase):
result = pickle.loads(data, encoding='bytes')
assert_equal(result, original)
+ if isinstance(result, np.ndarray) and result.dtype.names:
+ for name in result.dtype.names:
+ assert_(isinstance(name, str))
+
def test_pickle_dtype(self,level=rlevel):
"""Ticket #251"""
pickle.dumps(np.float)