diff options
author | NoƩ Rubinstein <noe.rubinstein@gmail.com> | 2023-01-27 19:39:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-27 19:39:38 +0100 |
commit | 41499995a4c532556b2b6d6e3c0fabb0a7bdb61a (patch) | |
tree | 9776bca523d90e02388442a38740a04eb4d41731 /numpy/lib | |
parent | 2289292db6a19f2bbfddd3dea3790ffa19955333 (diff) | |
download | numpy-41499995a4c532556b2b6d6e3c0fabb0a7bdb61a.tar.gz |
API: Raise EOFError when trying to load past the end of a `.npy` file (#23105)
Currently, the following code:
```
import numpy as np
with open('foo.npy', 'wb') as f:
for i in range(np.random.randint(10)):
np.save(f, 1)
with open('foo.npy', 'rb') as f:
while True:
np.load(f)
```
Will raise:
```
ValueError: Cannot load file containing pickled data when allow_pickle=False
```
While there is no pickled data in the file.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 5 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 10 |
2 files changed, 15 insertions, 0 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 0c1740df1..568195e18 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -327,6 +327,9 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, If ``allow_pickle=True``, but the file cannot be loaded as a pickle. ValueError The file contains an object array, but ``allow_pickle=False`` given. + EOFError + When calling ``np.load`` multiple times on the same file handle, + if all data has already been read See Also -------- @@ -410,6 +413,8 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, _ZIP_SUFFIX = b'PK\x05\x06' # empty zip files start with this N = len(format.MAGIC_PREFIX) magic = fid.read(N) + if not magic: + raise EOFError("No data left in file") # If the file size is less than N, we need to make sure not # to seek past the beginning of the file fid.seek(-min(N, len(magic)), 1) # back-up diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 4699935ca..cffe7e7ac 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -2737,3 +2737,13 @@ def test_load_refcount(): with assert_no_gc_cycles(): x = np.loadtxt(TextIO("0 1 2 3"), dtype=dt) assert_equal(x, np.array([((0, 1), (2, 3))], dtype=dt)) + +def test_load_multiple_arrays_until_eof(): + f = BytesIO() + np.save(f, 1) + np.save(f, 2) + f.seek(0) + assert np.load(f) == 1 + assert np.load(f) == 2 + with pytest.raises(EOFError): + np.load(f) |