summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorNoé Rubinstein <noe.rubinstein@gmail.com>2023-01-27 19:39:38 +0100
committerGitHub <noreply@github.com>2023-01-27 19:39:38 +0100
commit41499995a4c532556b2b6d6e3c0fabb0a7bdb61a (patch)
tree9776bca523d90e02388442a38740a04eb4d41731 /numpy/lib
parent2289292db6a19f2bbfddd3dea3790ffa19955333 (diff)
downloadnumpy-41499995a4c532556b2b6d6e3c0fabb0a7bdb61a.tar.gz
API: Raise EOFError when trying to load past the end of a `.npy` file (#23105)
Currently, the following code:

```
import numpy as np

with open('foo.npy', 'wb') as f:
    for i in range(np.random.randint(10)):
        np.save(f, 1)

with open('foo.npy', 'rb') as f:
    while True:
        np.load(f)
```

will raise:

```
ValueError: Cannot load file containing pickled data when allow_pickle=False
```

even though there is no pickled data in the file.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/npyio.py5
-rw-r--r--numpy/lib/tests/test_io.py10
2 files changed, 15 insertions, 0 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 0c1740df1..568195e18 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -327,6 +327,9 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True,
If ``allow_pickle=True``, but the file cannot be loaded as a pickle.
ValueError
The file contains an object array, but ``allow_pickle=False`` given.
+ EOFError
+ When calling ``np.load`` multiple times on the same file handle,
+ if all data has already been read
See Also
--------
@@ -410,6 +413,8 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True,
_ZIP_SUFFIX = b'PK\x05\x06' # empty zip files start with this
N = len(format.MAGIC_PREFIX)
magic = fid.read(N)
+ if not magic:
+ raise EOFError("No data left in file")
# If the file size is less than N, we need to make sure not
# to seek past the beginning of the file
fid.seek(-min(N, len(magic)), 1) # back-up
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 4699935ca..cffe7e7ac 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -2737,3 +2737,13 @@ def test_load_refcount():
with assert_no_gc_cycles():
x = np.loadtxt(TextIO("0 1 2 3"), dtype=dt)
assert_equal(x, np.array([((0, 1), (2, 3))], dtype=dt))
+
+def test_load_multiple_arrays_until_eof():
+ f = BytesIO()
+ np.save(f, 1)
+ np.save(f, 2)
+ f.seek(0)
+ assert np.load(f) == 1
+ assert np.load(f) == 2
+ with pytest.raises(EOFError):
+ np.load(f)