diff options
author | penghongyang <penghongyang@megvii.com> | 2023-05-09 18:24:13 +0800 |
---|---|---|
committer | penghongyang <penghongyang@megvii.com> | 2023-05-09 19:43:05 +0800 |
commit | 9d688f3ac13fe89cec59447e5550a9e548d96506 (patch) | |
tree | decf3c0149754ad29d29a90ea97bc39cd423dac5 | |
parent | 87921fe0cbdf4376e78f088898f503171a0cfcaa (diff) | |
download | numpy-9d688f3ac13fe89cec59447e5550a9e548d96506.tar.gz |
BUG: fix the method for checking local files
-rw-r--r-- | numpy/compat/py3k.py | 11 | ||||
-rw-r--r-- | numpy/compat/tests/test_compat.py | 18 | ||||
-rw-r--r-- | numpy/lib/format.py | 11 |
3 files changed, 34 insertions, 6 deletions
diff --git a/numpy/compat/py3k.py b/numpy/compat/py3k.py index 3d10bb988..6392421c8 100644 --- a/numpy/compat/py3k.py +++ b/numpy/compat/py3k.py @@ -49,6 +49,17 @@ def asstr(s): def isfileobj(f): return isinstance(f, (io.FileIO, io.BufferedReader, io.BufferedWriter)) +def _isfileobj(f): + if not isinstance(f, (io.FileIO, io.BufferedReader, io.BufferedWriter)): + return False + try: + # BufferedReader/Writer may raise OSError when + # fetching `fileno()` (e.g. when wrapping BytesIO). + f.fileno() + return True + except OSError: + return False + def open_latin1(filename, mode='r'): return open(filename, mode=mode, encoding='iso-8859-1') diff --git a/numpy/compat/tests/test_compat.py b/numpy/compat/tests/test_compat.py index 2b8acbaa0..644300329 100644 --- a/numpy/compat/tests/test_compat.py +++ b/numpy/compat/tests/test_compat.py @@ -1,6 +1,7 @@ from os.path import join +from io import BufferedReader, BytesIO -from numpy.compat import isfileobj +from numpy.compat.py3k import isfileobj, _isfileobj from numpy.testing import assert_ from numpy.testing import tempdir @@ -17,3 +18,18 @@ def test_isfileobj(): with open(filename, 'rb') as f: assert_(isfileobj(f)) + +def test__isfileobj(): + with tempdir(prefix="numpy_test_compat_") as folder: + filename = join(folder, 'a.bin') + + with open(filename, 'wb') as f: + assert_(_isfileobj(f)) + + with open(filename, 'ab') as f: + assert_(_isfileobj(f)) + + with open(filename, 'rb') as f: + assert_(_isfileobj(f)) + + assert_(_isfileobj(BufferedReader(BytesIO())) is False) diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 54fd0b0bc..9a19966eb 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -165,8 +165,9 @@ import numpy import warnings from numpy.lib.utils import safe_eval from numpy.compat import ( - isfileobj, os_fspath, pickle + os_fspath, pickle ) +from numpy.compat.py3k import _isfileobj __all__ = [] @@ -710,7 +711,7 @@ def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None): pickle_kwargs = {} pickle.dump(array, fp, protocol=3, **pickle_kwargs) elif array.flags.f_contiguous and not array.flags.c_contiguous: - if isfileobj(fp): + if _isfileobj(fp): array.T.tofile(fp) else: for chunk in numpy.nditer( @@ -718,7 +719,7 @@ def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None): buffersize=buffersize, order='F'): fp.write(chunk.tobytes('C')) else: - if isfileobj(fp): + if _isfileobj(fp): array.tofile(fp) else: for chunk in numpy.nditer( @@ -796,7 +797,7 @@ def read_array(fp, allow_pickle=False, pickle_kwargs=None, *, "You may need to pass the encoding= option " "to numpy.load" % (err,)) from err else: - if isfileobj(fp): + if _isfileobj(fp): # We can use the fast fromfile() function. array = numpy.fromfile(fp, dtype=dtype, count=count) else: @@ -888,7 +889,7 @@ def open_memmap(filename, mode='r+', dtype=None, shape=None, numpy.memmap """ - if isfileobj(filename): + if _isfileobj(filename): raise ValueError("Filename must be a string or a path-like object." " Memmap cannot use existing file handles.") |