diff options
-rw-r--r-- | doc/release/1.10.0-notes.rst | 8 | ||||
-rw-r--r-- | numpy/lib/format.py | 28 | ||||
-rw-r--r-- | numpy/lib/npyio.py | 70 | ||||
-rw-r--r-- | numpy/lib/tests/data/py2-objarr.npy | bin | 0 -> 258 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py2-objarr.npz | bin | 0 -> 366 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py3-objarr.npy | bin | 0 -> 341 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py3-objarr.npz | bin | 0 -> 449 bytes | |||
-rw-r--r-- | numpy/lib/tests/test_format.py | 67 |
8 files changed, 158 insertions, 15 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index 95cea6537..c38a2ae64 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -178,6 +178,14 @@ compare NaNs as equal by setting ``equal_nan=True``. Subclasses, such as crashed with an ``OverflowError`` in these cases). Integers larger than ``2**63-1`` are converted to floating-point values. +*np.load*, *np.save* have pickle backward compatibility flags +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The functions *np.load* and *np.save* have additional keyword +arguments for controlling backward compatibility of pickled Python +objects. This enables Numpy on Python 3 to load npy files containing +object arrays that were generated on Python 2. + Changes ======= diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 4ff0a660f..1a2133aa9 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -517,7 +517,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None): +def write_array(fp, array, version=None, pickle_kwargs=None): """ Write an array to an NPY file, including a header. @@ -535,6 +535,10 @@ def write_array(fp, array, version=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + pickle_kwargs : dict, optional + Additional keyword arguments to pass to pickle.dump, excluding + 'protocol'. These are only useful when pickling objects in object + arrays on Python 3 to Python 2 compatible format. Raises ------ @@ -561,7 +565,9 @@ def write_array(fp, array, version=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. - pickle.dump(array, fp, protocol=2) + if pickle_kwargs is None: + pickle_kwargs = {} + pickle.dump(array, fp, protocol=2, **pickle_kwargs) elif array.flags.f_contiguous and not array.flags.c_contiguous: if isfileobj(fp): array.T.tofile(fp) @@ -580,7 +586,7 @@ def write_array(fp, array, version=None): fp.write(chunk.tobytes('C')) -def read_array(fp): +def read_array(fp, pickle_kwargs=None): """ Read an array from an NPY file. @@ -589,6 +595,10 @@ def read_array(fp): fp : file_like object If this is not a real file object, then this may take extra memory and time. + pickle_kwargs : dict + Additional keyword arguments to pass to pickle.load. These are only + useful when loading object arrays saved on Python 2 when using + Python 3. Returns ------- @@ -612,7 +622,17 @@ def read_array(fp): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. - array = pickle.load(fp) + if pickle_kwargs is None: + pickle_kwargs = {} + try: + array = pickle.load(fp, **pickle_kwargs) + except UnicodeError as err: + if sys.version_info[0] >= 3: + # Friendlier error message + raise UnicodeError("Unpickling a python object failed: %r\n" + "You may need to pass the encoding= option " + "to numpy.load" % (err,)) + raise else: if isfileobj(fp): # We can use the fast fromfile() function. diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 2b01caed9..bf703bb76 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -164,6 +164,10 @@ class NpzFile(object): f : BagObj instance An object on which attribute can be performed as an alternative to getitem access on the `NpzFile` instance itself. + pickle_kwargs : dict, optional + Additional keyword arguments to pass on to pickle.load. + These are only useful when loading object arrays saved on + Python 2 when using Python 3. Parameters ---------- @@ -195,12 +199,13 @@ class NpzFile(object): """ - def __init__(self, fid, own_fid=False): + def __init__(self, fid, own_fid=False, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an # optional component of the so-called standard library. _zip = zipfile_factory(fid) self._files = _zip.namelist() self.files = [] + self.pickle_kwargs = pickle_kwargs for x in self._files: if x.endswith('.npy'): self.files.append(x[:-4]) @@ -256,7 +261,8 @@ class NpzFile(object): bytes.close() if magic == format.MAGIC_PREFIX: bytes = self.zip.open(key) - return format.read_array(bytes) + return format.read_array(bytes, + pickle_kwargs=self.pickle_kwargs) else: return self.zip.read(key) else: @@ -289,7 +295,7 @@ class NpzFile(object): return self.files.__contains__(key) -def load(file, mmap_mode=None): +def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): """ Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files. @@ -306,6 +312,18 @@ def load(file, mmap_mode=None): and sliced like any ndarray. Memory mapping is especially useful for accessing small fragments of large files without reading the entire file into memory. + fix_imports : bool, optional + Only useful when loading Python 2 generated pickled files on Python 3, + which includes npy/npz files containing object arrays. If `fix_imports` + is True, pickle will try to map the old Python 2 names to the new names + used in Python 3. + encoding : str, optional + What encoding to use when reading Python 2 strings. Only useful when + loading Python 2 generated pickled files on Python 3, which includes + npy/npz files containing object arrays. Values other than 'latin1', + 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical + data. Default: 'ASCII' + Returns ------- @@ -381,6 +399,26 @@ def load(file, mmap_mode=None): else: fid = file + if encoding not in ('ASCII', 'latin1', 'bytes'): + # The 'encoding' value for pickle also affects what encoding + # the serialized binary data of Numpy arrays is loaded + # in. Pickle does not pass on the encoding information to + # Numpy. The unpickling code in numpy.core.multiarray is + # written to assume that unicode data appearing where binary + # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'. + # + # Other encoding values can corrupt binary data, and we + # purposefully disallow them. For the same reason, the errors= + # argument is not exposed, as values other than 'strict' + # result can similarly silently corrupt numerical data. + raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'") + + if sys.version_info[0] >= 3: + pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = {} + try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = asbytes('PK\x03\x04') @@ -392,17 +430,17 @@ def load(file, mmap_mode=None): # Transfer file ownership to NpzFile tmp = own_fid own_fid = False - return NpzFile(fid, own_fid=tmp) + return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs) elif magic == format.MAGIC_PREFIX: # .npy file if mmap_mode: return format.open_memmap(file, mode=mmap_mode) else: - return format.read_array(fid) + return format.read_array(fid, pickle_kwargs=pickle_kwargs) else: # Try a pickle try: - return pickle.load(fid) + return pickle.load(fid, **pickle_kwargs) except: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) @@ -411,7 +449,7 @@ def load(file, mmap_mode=None): fid.close() -def save(file, arr): +def save(file, arr, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -422,6 +460,11 @@ def save(file, arr): then the filename is unchanged. If file is a string, a ``.npy`` extension will be appended to the file name if it does not already have one. + fix_imports : bool, optional + Only useful in forcing objects in object arrays on Python 3 to be + pickled in a Python 2 compatible way. If `fix_imports` is True, pickle + will try to map the new Python 3 names to the old module names used in + Python 2, so that the pickle data stream is readable with Python 2. arr : array_like Array data to be saved. @@ -458,9 +501,15 @@ def save(file, arr): else: fid = file + if sys.version_info[0] >= 3: + pickle_kwargs = dict(fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = None + try: arr = np.asanyarray(arr) - format.write_array(fid, arr) + format.write_array(fid, arr, pickle_kwargs=pickle_kwargs) finally: if own_fid: fid.close() @@ -572,7 +621,7 @@ def savez_compressed(file, *args, **kwds): _savez(file, args, kwds, True) -def _savez(file, args, kwds, compress): +def _savez(file, args, kwds, compress, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an optional # component of the so-called standard library. import zipfile @@ -606,7 +655,8 @@ def _savez(file, args, kwds, compress): fname = key + '.npy' fid = open(tmpfile, 'wb') try: - format.write_array(fid, np.asanyarray(val)) + format.write_array(fid, np.asanyarray(val), + pickle_kwargs=pickle_kwargs) fid.close() fid = None zipf.write(tmpfile, arcname=fname) diff --git a/numpy/lib/tests/data/py2-objarr.npy b/numpy/lib/tests/data/py2-objarr.npy Binary files differnew file mode 100644 index 000000000..12936c92d --- /dev/null +++ b/numpy/lib/tests/data/py2-objarr.npy diff --git a/numpy/lib/tests/data/py2-objarr.npz b/numpy/lib/tests/data/py2-objarr.npz Binary files differnew file mode 100644 index 000000000..68a3b53a1 --- /dev/null +++ b/numpy/lib/tests/data/py2-objarr.npz diff --git a/numpy/lib/tests/data/py3-objarr.npy b/numpy/lib/tests/data/py3-objarr.npy Binary files differnew file mode 100644 index 000000000..6776074b4 --- /dev/null +++ b/numpy/lib/tests/data/py3-objarr.npy diff --git a/numpy/lib/tests/data/py3-objarr.npz b/numpy/lib/tests/data/py3-objarr.npz Binary files differnew file mode 100644 index 000000000..05eac0b76 --- /dev/null +++ b/numpy/lib/tests/data/py3-objarr.npz diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index ee77386bc..169f01182 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -284,7 +284,7 @@ import warnings from io import BytesIO import numpy as np -from numpy.compat import asbytes, asbytes_nested +from numpy.compat import asbytes, asbytes_nested, sixu from numpy.testing import ( run_module_suite, assert_, assert_array_equal, assert_raises, raises, dec @@ -534,6 +534,71 @@ def test_python2_python3_interoperability(): assert_array_equal(data, np.ones(2)) +def test_pickle_python2_python3(): + # Test that loading object arrays saved on Python 2 works both on + # Python 2 and Python 3 and vice versa + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + if sys.version_info[0] >= 3: + xrange = range + else: + import __builtin__ + xrange = __builtin__.xrange + + expected = np.array([None, xrange, sixu('\u512a\u826f'), + asbytes('\xe4\xb8\x8d\xe8\x89\xaf')], + dtype=object) + + for fname in ['py2-objarr.npy', 'py2-objarr.npz', + 'py3-objarr.npy', 'py3-objarr.npz']: + path = os.path.join(data_dir, fname) + + if (fname.endswith('.npz') and sys.version_info[0] == 2 and + sys.version_info[1] < 7): + # Reading object arrays directly from zipfile appears to fail + # on Py2.6, see cfae0143b4 + continue + + for encoding in ['bytes', 'latin1']: + if (sys.version_info[0] >= 3 and sys.version_info[1] < 4 and + encoding == 'bytes'): + # The bytes encoding is available starting from Python 3.4 + continue + + data_f = np.load(path, encoding=encoding) + if fname.endswith('.npz'): + data = data_f['x'] + data_f.close() + else: + data = data_f + + if sys.version_info[0] >= 3: + if encoding == 'latin1' and fname.startswith('py2'): + assert_(isinstance(data[3], str)) + assert_array_equal(data[:-1], expected[:-1]) + # mojibake occurs + assert_array_equal(data[-1].encode(encoding), expected[-1]) + else: + assert_(isinstance(data[3], bytes)) + assert_array_equal(data, expected) + else: + assert_array_equal(data, expected) + + if sys.version_info[0] >= 3: + if fname.startswith('py2'): + if fname.endswith('.npz'): + data = np.load(path) + assert_raises(UnicodeError, data.__getitem__, 'x') + data.close() + data = np.load(path, fix_imports=False, encoding='latin1') + assert_raises(ImportError, data.__getitem__, 'x') + data.close() + else: + assert_raises(UnicodeError, np.load, path) + assert_raises(ImportError, np.load, path, + encoding='latin1', fix_imports=False) + + def test_version_2_0(): f = BytesIO() # requires more than 2 byte for header |