summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Julian Taylor <juliantaylor108@gmail.com> 2014-07-23 01:16:58 +0200
committer: Julian Taylor <juliantaylor108@gmail.com> 2014-07-23 01:16:58 +0200
commit: 809938d8dd481f789a3295a2cccce5aa521a6344 (patch)
tree: f2c6c55afa3617fb70a6dca0ee37d6cbb7513fe3
parent: a28bfa5780c6611c9c46977a0e9c2123426cf24a (diff)
parent: 16f39c80f69ae4695f2f940ccdff8b26db3a2a01 (diff)
download: numpy-809938d8dd481f789a3295a2cccce5aa521a6344.tar.gz
Merge pull request #4888 from pv/fix-bytes-encoding-unpickle
ENH: core: make unpickling with encoding='bytes' work
-rw-r--r-- numpy/core/src/multiarray/descriptor.c 181
-rw-r--r-- numpy/core/tests/test_regression.py 35
2 files changed, 166 insertions, 50 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 8b55c9fbd..238077b36 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -2369,11 +2369,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
{
int elsize = -1, alignment = -1;
int version = 4;
-#if defined(NPY_PY3K)
- int endian;
-#else
char endian;
-#endif
+ PyObject *endian_obj;
PyObject *subarray, *fields, *names = NULL, *metadata=NULL;
int incref_names = 1;
int int_dtypeflags = 0;
@@ -2390,68 +2387,39 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
switch (PyTuple_GET_SIZE(PyTuple_GET_ITEM(args,0))) {
case 9:
-#if defined(NPY_PY3K)
-#define _ARGSTR_ "(iCOOOiiiO)"
-#else
-#define _ARGSTR_ "(icOOOiiiO)"
-#endif
- if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
+ if (!PyArg_ParseTuple(args, "(iOOOOiiiO)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags, &metadata)) {
+ PyErr_Clear();
return NULL;
-#undef _ARGSTR_
}
break;
case 8:
-#if defined(NPY_PY3K)
-#define _ARGSTR_ "(iCOOOiii)"
-#else
-#define _ARGSTR_ "(icOOOiii)"
-#endif
- if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
+ if (!PyArg_ParseTuple(args, "(iOOOOiii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags)) {
return NULL;
-#undef _ARGSTR_
}
break;
case 7:
-#if defined(NPY_PY3K)
-#define _ARGSTR_ "(iCOOOii)"
-#else
-#define _ARGSTR_ "(icOOOii)"
-#endif
- if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
+ if (!PyArg_ParseTuple(args, "(iOOOOii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment)) {
return NULL;
-#undef _ARGSTR_
}
break;
case 6:
-#if defined(NPY_PY3K)
-#define _ARGSTR_ "(iCOOii)"
-#else
-#define _ARGSTR_ "(icOOii)"
-#endif
- if (!PyArg_ParseTuple(args, _ARGSTR_, &version,
- &endian, &subarray, &fields,
+ if (!PyArg_ParseTuple(args, "(iOOOii)", &version,
+ &endian_obj, &subarray, &fields,
&elsize, &alignment)) {
- PyErr_Clear();
-#undef _ARGSTR_
+ return NULL;
}
break;
case 5:
version = 0;
-#if defined(NPY_PY3K)
-#define _ARGSTR_ "(COOii)"
-#else
-#define _ARGSTR_ "(cOOii)"
-#endif
- if (!PyArg_ParseTuple(args, _ARGSTR_,
- &endian, &subarray, &fields, &elsize,
+ if (!PyArg_ParseTuple(args, "(OOOii)",
+ &endian_obj, &subarray, &fields, &elsize,
&alignment)) {
-#undef _ARGSTR_
return NULL;
}
break;
@@ -2494,11 +2462,55 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
}
+ /* Parse endian */
+ if (PyUnicode_Check(endian_obj) || PyBytes_Check(endian_obj)) {
+ PyObject *tmp = NULL;
+ char *str;
+ Py_ssize_t len;
+
+ if (PyUnicode_Check(endian_obj)) {
+ tmp = PyUnicode_AsASCIIString(endian_obj);
+ if (tmp == NULL) {
+ return NULL;
+ }
+ endian_obj = tmp;
+ }
+
+ if (PyBytes_AsStringAndSize(endian_obj, &str, &len) == -1) {
+ Py_XDECREF(tmp);
+ return NULL;
+ }
+ if (len != 1) {
+ PyErr_SetString(PyExc_ValueError,
+ "endian is not 1-char string in Numpy dtype unpickling");
+ Py_XDECREF(tmp);
+ return NULL;
+ }
+ endian = str[0];
+ Py_XDECREF(tmp);
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "endian is not a string in Numpy dtype unpickling");
+ return NULL;
+ }
if ((fields == Py_None && names != Py_None) ||
(names == Py_None && fields != Py_None)) {
PyErr_Format(PyExc_ValueError,
- "inconsistent fields and names");
+ "inconsistent fields and names in Numpy dtype unpickling");
+ return NULL;
+ }
+
+ if (names != Py_None && !PyTuple_Check(names)) {
+ PyErr_Format(PyExc_ValueError,
+ "non-tuple names in Numpy dtype unpickling");
+ return NULL;
+ }
+
+ if (fields != Py_None && !PyDict_Check(fields)) {
+ PyErr_Format(PyExc_ValueError,
+ "non-dict fields in Numpy dtype unpickling");
return NULL;
}
@@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
if (fields != Py_None) {
- Py_XDECREF(self->fields);
- self->fields = fields;
- Py_INCREF(fields);
- Py_XDECREF(self->names);
- self->names = names;
- if (incref_names) {
- Py_INCREF(names);
+ /*
+ * Ensure names are of appropriate string type
+ */
+ Py_ssize_t i;
+ int names_ok = 1;
+ PyObject *name;
+
+ for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
+ name = PyTuple_GET_ITEM(names, i);
+ if (!PyUString_Check(name)) {
+ names_ok = 0;
+ break;
+ }
+ }
+
+ if (names_ok) {
+ Py_XDECREF(self->fields);
+ self->fields = fields;
+ Py_INCREF(fields);
+ Py_XDECREF(self->names);
+ self->names = names;
+ if (incref_names) {
+ Py_INCREF(names);
+ }
+ }
+ else {
+#if defined(NPY_PY3K)
+ /*
+ * To support pickle.load(f, encoding='bytes') for loading Py2
+ * generated pickles on Py3, we need to be more lenient and convert
+ * field names from byte strings to unicode.
+ */
+ PyObject *tmp, *new_name, *field;
+
+ tmp = PyDict_New();
+ if (tmp == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(self->fields);
+ self->fields = tmp;
+
+ tmp = PyTuple_New(PyTuple_GET_SIZE(names));
+ if (tmp == NULL) {
+ return NULL;
+ }
+ Py_XDECREF(self->names);
+ self->names = tmp;
+
+ for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
+ name = PyTuple_GET_ITEM(names, i);
+ field = PyDict_GetItem(fields, name);
+ if (!field) {
+ return NULL;
+ }
+
+ if (PyUnicode_Check(name)) {
+ new_name = name;
+ Py_INCREF(new_name);
+ }
+ else {
+ new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict");
+ if (new_name == NULL) {
+ return NULL;
+ }
+ }
+
+ PyTuple_SET_ITEM(self->names, i, new_name);
+ if (PyDict_SetItem(self->fields, new_name, field) != 0) {
+ return NULL;
+ }
+ }
+#else
+ PyErr_Format(PyExc_ValueError,
+ "non-string names in Numpy dtype unpickling");
+ return NULL;
+#endif
}
}
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 5d36c9b1b..c7eaad984 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -398,6 +398,41 @@ class TestRegression(TestCase):
assert_raises(KeyError, np.lexsort, BuggySequence())
+ def test_pickle_py2_bytes_encoding(self):
+ # Check that arrays and scalars pickled on Py2 are
+ # unpickleable on Py3 using encoding='bytes'
+
+ test_data = [
+ # (original, py2_pickle)
+ (np.unicode_('\u6f2c'),
+ asbytes("cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n"
+ "(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\n"
+ "I0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n.")),
+
+ (np.array([9e123], dtype=np.float64),
+ asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\n"
+ "p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n"
+ "p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n"
+ "I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")),
+
+ (np.array([(9e123,)], dtype=[('name', float)]),
+ asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n"
+ "(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n"
+ "(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n"
+ "(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n"
+ "I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n"
+ "bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")),
+ ]
+
+ if sys.version_info[:2] >= (3, 4):
+ # encoding='bytes' was added in Py3.4
+ for original, data in test_data:
+ result = pickle.loads(data, encoding='bytes')
+ assert_equal(result, original)
+
+ if isinstance(result, np.ndarray) and result.dtype.names:
+ for name in result.dtype.names:
+ assert_(isinstance(name, str))
def test_pickle_dtype(self,level=rlevel):
"""Ticket #251"""