summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpenghongyang <penghongyang@megvii.com>2023-05-09 18:24:13 +0800
committerpenghongyang <penghongyang@megvii.com>2023-05-09 19:43:05 +0800
commit9d688f3ac13fe89cec59447e5550a9e548d96506 (patch)
treedecf3c0149754ad29d29a90ea97bc39cd423dac5
parent87921fe0cbdf4376e78f088898f503171a0cfcaa (diff)
downloadnumpy-9d688f3ac13fe89cec59447e5550a9e548d96506.tar.gz
BUG: fix the method for checking local files
-rw-r--r--numpy/compat/py3k.py11
-rw-r--r--numpy/compat/tests/test_compat.py18
-rw-r--r--numpy/lib/format.py11
3 files changed, 34 insertions, 6 deletions
diff --git a/numpy/compat/py3k.py b/numpy/compat/py3k.py
index 3d10bb988..6392421c8 100644
--- a/numpy/compat/py3k.py
+++ b/numpy/compat/py3k.py
@@ -49,6 +49,17 @@ def asstr(s):
def isfileobj(f):
return isinstance(f, (io.FileIO, io.BufferedReader, io.BufferedWriter))
+def _isfileobj(f):
+ if not isinstance(f, (io.FileIO, io.BufferedReader, io.BufferedWriter)):
+ return False
+ try:
+ # BufferedReader/Writer may raise OSError when
+ # fetching `fileno()` (e.g. when wrapping BytesIO).
+ f.fileno()
+ return True
+ except OSError:
+ return False
+
def open_latin1(filename, mode='r'):
return open(filename, mode=mode, encoding='iso-8859-1')
diff --git a/numpy/compat/tests/test_compat.py b/numpy/compat/tests/test_compat.py
index 2b8acbaa0..644300329 100644
--- a/numpy/compat/tests/test_compat.py
+++ b/numpy/compat/tests/test_compat.py
@@ -1,6 +1,7 @@
from os.path import join
+from io import BufferedReader, BytesIO
-from numpy.compat import isfileobj
+from numpy.compat.py3k import isfileobj, _isfileobj
from numpy.testing import assert_
from numpy.testing import tempdir
@@ -17,3 +18,18 @@ def test_isfileobj():
with open(filename, 'rb') as f:
assert_(isfileobj(f))
+
+def test__isfileobj():
+ with tempdir(prefix="numpy_test_compat_") as folder:
+ filename = join(folder, 'a.bin')
+
+ with open(filename, 'wb') as f:
+ assert_(_isfileobj(f))
+
+ with open(filename, 'ab') as f:
+ assert_(_isfileobj(f))
+
+ with open(filename, 'rb') as f:
+ assert_(_isfileobj(f))
+
+ assert_(_isfileobj(BufferedReader(BytesIO())) is False)
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 54fd0b0bc..9a19966eb 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -165,8 +165,9 @@ import numpy
import warnings
from numpy.lib.utils import safe_eval
from numpy.compat import (
- isfileobj, os_fspath, pickle
+ os_fspath, pickle
)
+from numpy.compat.py3k import _isfileobj
__all__ = []
@@ -710,7 +711,7 @@ def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
pickle_kwargs = {}
pickle.dump(array, fp, protocol=3, **pickle_kwargs)
elif array.flags.f_contiguous and not array.flags.c_contiguous:
- if isfileobj(fp):
+ if _isfileobj(fp):
array.T.tofile(fp)
else:
for chunk in numpy.nditer(
@@ -718,7 +719,7 @@ def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
buffersize=buffersize, order='F'):
fp.write(chunk.tobytes('C'))
else:
- if isfileobj(fp):
+ if _isfileobj(fp):
array.tofile(fp)
else:
for chunk in numpy.nditer(
@@ -796,7 +797,7 @@ def read_array(fp, allow_pickle=False, pickle_kwargs=None, *,
"You may need to pass the encoding= option "
"to numpy.load" % (err,)) from err
else:
- if isfileobj(fp):
+ if _isfileobj(fp):
# We can use the fast fromfile() function.
array = numpy.fromfile(fp, dtype=dtype, count=count)
else:
@@ -888,7 +889,7 @@ def open_memmap(filename, mode='r+', dtype=None, shape=None,
numpy.memmap
"""
- if isfileobj(filename):
+ if _isfileobj(filename):
raise ValueError("Filename must be a string or a path-like object."
" Memmap cannot use existing file handles.")