summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/1.10.0-notes.rst8
-rw-r--r--numpy/lib/format.py28
-rw-r--r--numpy/lib/npyio.py70
-rw-r--r--numpy/lib/tests/data/py2-objarr.npybin0 -> 258 bytes
-rw-r--r--numpy/lib/tests/data/py2-objarr.npzbin0 -> 366 bytes
-rw-r--r--numpy/lib/tests/data/py3-objarr.npybin0 -> 341 bytes
-rw-r--r--numpy/lib/tests/data/py3-objarr.npzbin0 -> 449 bytes
-rw-r--r--numpy/lib/tests/test_format.py67
8 files changed, 158 insertions, 15 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst
index 95cea6537..c38a2ae64 100644
--- a/doc/release/1.10.0-notes.rst
+++ b/doc/release/1.10.0-notes.rst
@@ -178,6 +178,14 @@ compare NaNs as equal by setting ``equal_nan=True``. Subclasses, such as
crashed with an ``OverflowError`` in these cases). Integers larger than
``2**63-1`` are converted to floating-point values.
+*np.load*, *np.save* have pickle backward compatibility flags
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The functions *np.load* and *np.save* have additional keyword
+arguments for controlling backward compatibility of pickled Python
+objects. This enables Numpy on Python 3 to load npy files containing
+object arrays that were generated on Python 2.
+
Changes
=======
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 4ff0a660f..1a2133aa9 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -517,7 +517,7 @@ def _read_array_header(fp, version):
return d['shape'], d['fortran_order'], dtype
-def write_array(fp, array, version=None):
+def write_array(fp, array, version=None, pickle_kwargs=None):
"""
Write an array to an NPY file, including a header.
@@ -535,6 +535,10 @@ def write_array(fp, array, version=None):
version : (int, int) or None, optional
The version number of the format. None means use the oldest
supported version that is able to store the data. Default: None
+ pickle_kwargs : dict, optional
+ Additional keyword arguments to pass to pickle.dump, excluding
+ 'protocol'. These are only useful when pickling objects in object
+ arrays on Python 3 to Python 2 compatible format.
Raises
------
@@ -561,7 +565,9 @@ def write_array(fp, array, version=None):
# We contain Python objects so we cannot write out the data
# directly. Instead, we will pickle it out with version 2 of the
# pickle protocol.
- pickle.dump(array, fp, protocol=2)
+ if pickle_kwargs is None:
+ pickle_kwargs = {}
+ pickle.dump(array, fp, protocol=2, **pickle_kwargs)
elif array.flags.f_contiguous and not array.flags.c_contiguous:
if isfileobj(fp):
array.T.tofile(fp)
@@ -580,7 +586,7 @@ def write_array(fp, array, version=None):
fp.write(chunk.tobytes('C'))
-def read_array(fp):
+def read_array(fp, pickle_kwargs=None):
"""
Read an array from an NPY file.
@@ -589,6 +595,10 @@ def read_array(fp):
fp : file_like object
If this is not a real file object, then this may take extra memory
and time.
+ pickle_kwargs : dict
+ Additional keyword arguments to pass to pickle.load. These are only
+ useful when loading object arrays saved on Python 2 when using
+ Python 3.
Returns
-------
@@ -612,7 +622,17 @@ def read_array(fp):
# Now read the actual data.
if dtype.hasobject:
# The array contained Python objects. We need to unpickle the data.
- array = pickle.load(fp)
+ if pickle_kwargs is None:
+ pickle_kwargs = {}
+ try:
+ array = pickle.load(fp, **pickle_kwargs)
+ except UnicodeError as err:
+ if sys.version_info[0] >= 3:
+ # Friendlier error message
+ raise UnicodeError("Unpickling a python object failed: %r\n"
+ "You may need to pass the encoding= option "
+ "to numpy.load" % (err,))
+ raise
else:
if isfileobj(fp):
# We can use the fast fromfile() function.
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 2b01caed9..bf703bb76 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -164,6 +164,10 @@ class NpzFile(object):
f : BagObj instance
An object on which attribute can be performed as an alternative
to getitem access on the `NpzFile` instance itself.
+ pickle_kwargs : dict, optional
+ Additional keyword arguments to pass on to pickle.load.
+ These are only useful when loading object arrays saved on
+ Python 2 when using Python 3.
Parameters
----------
@@ -195,12 +199,13 @@ class NpzFile(object):
"""
- def __init__(self, fid, own_fid=False):
+ def __init__(self, fid, own_fid=False, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an
# optional component of the so-called standard library.
_zip = zipfile_factory(fid)
self._files = _zip.namelist()
self.files = []
+ self.pickle_kwargs = pickle_kwargs
for x in self._files:
if x.endswith('.npy'):
self.files.append(x[:-4])
@@ -256,7 +261,8 @@ class NpzFile(object):
bytes.close()
if magic == format.MAGIC_PREFIX:
bytes = self.zip.open(key)
- return format.read_array(bytes)
+ return format.read_array(bytes,
+ pickle_kwargs=self.pickle_kwargs)
else:
return self.zip.read(key)
else:
@@ -289,7 +295,7 @@ class NpzFile(object):
return self.files.__contains__(key)
-def load(file, mmap_mode=None):
+def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
"""
Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files.
@@ -306,6 +312,18 @@ def load(file, mmap_mode=None):
and sliced like any ndarray. Memory mapping is especially useful
for accessing small fragments of large files without reading the
entire file into memory.
+ fix_imports : bool, optional
+ Only useful when loading Python 2 generated pickled files on Python 3,
+ which includes npy/npz files containing object arrays. If `fix_imports`
+ is True, pickle will try to map the old Python 2 names to the new names
+ used in Python 3.
+ encoding : str, optional
+ What encoding to use when reading Python 2 strings. Only useful when
+ loading Python 2 generated pickled files on Python 3, which includes
+ npy/npz files containing object arrays. Values other than 'latin1',
+ 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical
+ data. Default: 'ASCII'
+
Returns
-------
@@ -381,6 +399,26 @@ def load(file, mmap_mode=None):
else:
fid = file
+ if encoding not in ('ASCII', 'latin1', 'bytes'):
+ # The 'encoding' value for pickle also affects what encoding
+ # the serialized binary data of Numpy arrays is loaded
+ # in. Pickle does not pass on the encoding information to
+ # Numpy. The unpickling code in numpy.core.multiarray is
+ # written to assume that unicode data appearing where binary
+ # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'.
+ #
+ # Other encoding values can corrupt binary data, and we
+ # purposefully disallow them. For the same reason, the errors=
+ # argument is not exposed, as values other than 'strict'
+ # result can similarly silently corrupt numerical data.
+ raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'")
+
+ if sys.version_info[0] >= 3:
+ pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports)
+ else:
+ # Nothing to do on Python 2
+ pickle_kwargs = {}
+
try:
# Code to distinguish from NumPy binary files and pickles.
_ZIP_PREFIX = asbytes('PK\x03\x04')
@@ -392,17 +430,17 @@ def load(file, mmap_mode=None):
# Transfer file ownership to NpzFile
tmp = own_fid
own_fid = False
- return NpzFile(fid, own_fid=tmp)
+ return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs)
elif magic == format.MAGIC_PREFIX:
# .npy file
if mmap_mode:
return format.open_memmap(file, mode=mmap_mode)
else:
- return format.read_array(fid)
+ return format.read_array(fid, pickle_kwargs=pickle_kwargs)
else:
# Try a pickle
try:
- return pickle.load(fid)
+ return pickle.load(fid, **pickle_kwargs)
except:
raise IOError(
"Failed to interpret file %s as a pickle" % repr(file))
@@ -411,7 +449,7 @@ def load(file, mmap_mode=None):
fid.close()
-def save(file, arr):
+def save(file, arr, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
@@ -422,6 +460,11 @@ def save(file, arr):
then the filename is unchanged. If file is a string, a ``.npy``
extension will be appended to the file name if it does not already
have one.
+ fix_imports : bool, optional
+ Only useful in forcing objects in object arrays on Python 3 to be
+ pickled in a Python 2 compatible way. If `fix_imports` is True, pickle
+ will try to map the new Python 3 names to the old module names used in
+ Python 2, so that the pickle data stream is readable with Python 2.
arr : array_like
Array data to be saved.
@@ -458,9 +501,15 @@ def save(file, arr):
else:
fid = file
+ if sys.version_info[0] >= 3:
+ pickle_kwargs = dict(fix_imports=fix_imports)
+ else:
+ # Nothing to do on Python 2
+ pickle_kwargs = None
+
try:
arr = np.asanyarray(arr)
- format.write_array(fid, arr)
+ format.write_array(fid, arr, pickle_kwargs=pickle_kwargs)
finally:
if own_fid:
fid.close()
@@ -572,7 +621,7 @@ def savez_compressed(file, *args, **kwds):
_savez(file, args, kwds, True)
-def _savez(file, args, kwds, compress):
+def _savez(file, args, kwds, compress, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an optional
# component of the so-called standard library.
import zipfile
@@ -606,7 +655,8 @@ def _savez(file, args, kwds, compress):
fname = key + '.npy'
fid = open(tmpfile, 'wb')
try:
- format.write_array(fid, np.asanyarray(val))
+ format.write_array(fid, np.asanyarray(val),
+ pickle_kwargs=pickle_kwargs)
fid.close()
fid = None
zipf.write(tmpfile, arcname=fname)
diff --git a/numpy/lib/tests/data/py2-objarr.npy b/numpy/lib/tests/data/py2-objarr.npy
new file mode 100644
index 000000000..12936c92d
--- /dev/null
+++ b/numpy/lib/tests/data/py2-objarr.npy
Binary files differ
diff --git a/numpy/lib/tests/data/py2-objarr.npz b/numpy/lib/tests/data/py2-objarr.npz
new file mode 100644
index 000000000..68a3b53a1
--- /dev/null
+++ b/numpy/lib/tests/data/py2-objarr.npz
Binary files differ
diff --git a/numpy/lib/tests/data/py3-objarr.npy b/numpy/lib/tests/data/py3-objarr.npy
new file mode 100644
index 000000000..6776074b4
--- /dev/null
+++ b/numpy/lib/tests/data/py3-objarr.npy
Binary files differ
diff --git a/numpy/lib/tests/data/py3-objarr.npz b/numpy/lib/tests/data/py3-objarr.npz
new file mode 100644
index 000000000..05eac0b76
--- /dev/null
+++ b/numpy/lib/tests/data/py3-objarr.npz
Binary files differ
diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py
index ee77386bc..169f01182 100644
--- a/numpy/lib/tests/test_format.py
+++ b/numpy/lib/tests/test_format.py
@@ -284,7 +284,7 @@ import warnings
from io import BytesIO
import numpy as np
-from numpy.compat import asbytes, asbytes_nested
+from numpy.compat import asbytes, asbytes_nested, sixu
from numpy.testing import (
run_module_suite, assert_, assert_array_equal, assert_raises, raises,
dec
@@ -534,6 +534,71 @@ def test_python2_python3_interoperability():
assert_array_equal(data, np.ones(2))
+def test_pickle_python2_python3():
+ # Test that loading object arrays saved on Python 2 works both on
+ # Python 2 and Python 3 and vice versa
+ data_dir = os.path.join(os.path.dirname(__file__), 'data')
+
+ if sys.version_info[0] >= 3:
+ xrange = range
+ else:
+ import __builtin__
+ xrange = __builtin__.xrange
+
+ expected = np.array([None, xrange, sixu('\u512a\u826f'),
+ asbytes('\xe4\xb8\x8d\xe8\x89\xaf')],
+ dtype=object)
+
+ for fname in ['py2-objarr.npy', 'py2-objarr.npz',
+ 'py3-objarr.npy', 'py3-objarr.npz']:
+ path = os.path.join(data_dir, fname)
+
+ if (fname.endswith('.npz') and sys.version_info[0] == 2 and
+ sys.version_info[1] < 7):
+ # Reading object arrays directly from zipfile appears to fail
+ # on Py2.6, see cfae0143b4
+ continue
+
+ for encoding in ['bytes', 'latin1']:
+ if (sys.version_info[0] >= 3 and sys.version_info[1] < 4 and
+ encoding == 'bytes'):
+ # The bytes encoding is available starting from Python 3.4
+ continue
+
+ data_f = np.load(path, encoding=encoding)
+ if fname.endswith('.npz'):
+ data = data_f['x']
+ data_f.close()
+ else:
+ data = data_f
+
+ if sys.version_info[0] >= 3:
+ if encoding == 'latin1' and fname.startswith('py2'):
+ assert_(isinstance(data[3], str))
+ assert_array_equal(data[:-1], expected[:-1])
+ # mojibake occurs
+ assert_array_equal(data[-1].encode(encoding), expected[-1])
+ else:
+ assert_(isinstance(data[3], bytes))
+ assert_array_equal(data, expected)
+ else:
+ assert_array_equal(data, expected)
+
+ if sys.version_info[0] >= 3:
+ if fname.startswith('py2'):
+ if fname.endswith('.npz'):
+ data = np.load(path)
+ assert_raises(UnicodeError, data.__getitem__, 'x')
+ data.close()
+ data = np.load(path, fix_imports=False, encoding='latin1')
+ assert_raises(ImportError, data.__getitem__, 'x')
+ data.close()
+ else:
+ assert_raises(UnicodeError, np.load, path)
+ assert_raises(ImportError, np.load, path,
+ encoding='latin1', fix_imports=False)
+
+
def test_version_2_0():
f = BytesIO()
# requires more than 2 byte for header