diff options
author | Bartosz Telenczuk <muchatel@poczta.fm> | 2013-01-22 13:45:37 +0100 |
---|---|---|
committer | Bartosz Telenczuk <muchatel@poczta.fm> | 2013-06-12 13:34:28 +0200 |
commit | cfae0143b436c3296eebe71e2dd730625dcaae95 (patch) | |
tree | c4f9999bc0c073fc3c87d9a89d02a39ca694ec48 /numpy/lib | |
parent | d4b4ff038d536500e4bfd16f164d88a1a99f5ac3 (diff) | |
download | numpy-cfae0143b436c3296eebe71e2dd730625dcaae95.tar.gz |
BUG: fix loading large npz files (fixes #2922)
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 12 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 19 |
2 files changed, 25 insertions, 6 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index fbcb5a46e..d400b4d30 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -244,12 +244,14 @@ class NpzFile(object): member = 1 key += '.npy' if member: - bytes = self.zip.read(key) - if bytes.startswith(format.MAGIC_PREFIX): - value = BytesIO(bytes) - return format.read_array(value) + bytes = self.zip.open(key) + magic = bytes.read(len(format.MAGIC_PREFIX)) + bytes.close() + if magic == format.MAGIC_PREFIX: + bytes = self.zip.open(key) + return format.read_array(bytes) else: - return bytes + return self.zip.read(key) else: raise KeyError("%s is not a file in the archive" % key) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 5987a15b0..aae95ed86 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -4,10 +4,10 @@ import sys import gzip import os import threading +from tempfile import mkstemp, mktemp, NamedTemporaryFile import time import warnings import gc -from tempfile import mkstemp, NamedTemporaryFile from io import BytesIO from datetime import datetime from numpy.testing.utils import WarningManager @@ -44,6 +44,9 @@ class TextIO(BytesIO): MAJVER, MINVER = sys.version_info[:2] +def is_64bit_platform(): + return sys.maxsize> 2**32 + def strptime(s, fmt=None): """This function is available in the datetime module only from Python >= 2.5. @@ -139,6 +142,20 @@ class TestSavezLoad(RoundtripTest, TestCase): for n, arr in enumerate(self.arr): assert_equal(arr, self.arr_reloaded['arr_%d' % n]) + + @np.testing.dec.skipif(not is_64bit_platform(), "Works only with 64bit systems") + @np.testing.dec.slow + def test_big_arrays(self): + L = 2**31+1 + tmp = mktemp(suffix='.npz') + a = np.empty(L, dtype=np.uint8) + np.savez(tmp, a=a) + del a + npfile = np.load(tmp) + a = npfile['a'] + npfile.close() + os.remove(tmp) + def test_multiple_arrays(self): a = np.array([[1, 2], [3, 4]], float) b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex) |