summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2016-02-11 13:21:30 +0200
committerSerhiy Storchaka <storchaka@gmail.com>2016-02-11 13:21:30 +0200
commit8ffbed9e0d231761a83ffe3547d4e2d28291782e (patch)
tree5afd310c8c281218ae2eb1bf7530ce18eb6999ae /Lib
parent20ac73fb7133c119974eb21693d2621b09585ca0 (diff)
downloadcpython-8ffbed9e0d231761a83ffe3547d4e2d28291782e.tar.gz
Issue #25994: Added the close() method and the support of the context manager
protocol for the os.scandir() iterator.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/os.py94
-rw-r--r--Lib/test/test_os.py52
2 files changed, 107 insertions, 39 deletions
diff --git a/Lib/os.py b/Lib/os.py
index 674a7d7efd..c3f674ec3d 100644
--- a/Lib/os.py
+++ b/Lib/os.py
@@ -374,46 +374,47 @@ def walk(top, topdown=True, onerror=None, followlinks=False):
onerror(error)
return
- while True:
- try:
+ with scandir_it:
+ while True:
try:
- entry = next(scandir_it)
- except StopIteration:
- break
- except OSError as error:
- if onerror is not None:
- onerror(error)
- return
-
- try:
- is_dir = entry.is_dir()
- except OSError:
- # If is_dir() raises an OSError, consider that the entry is not
- # a directory, same behaviour than os.path.isdir().
- is_dir = False
-
- if is_dir:
- dirs.append(entry.name)
- else:
- nondirs.append(entry.name)
+ try:
+ entry = next(scandir_it)
+ except StopIteration:
+ break
+ except OSError as error:
+ if onerror is not None:
+ onerror(error)
+ return
- if not topdown and is_dir:
- # Bottom-up: recurse into sub-directory, but exclude symlinks to
- # directories if followlinks is False
- if followlinks:
- walk_into = True
+ try:
+ is_dir = entry.is_dir()
+ except OSError:
+ # If is_dir() raises an OSError, consider that the entry is not
+ # a directory, same behaviour than os.path.isdir().
+ is_dir = False
+
+ if is_dir:
+ dirs.append(entry.name)
else:
- try:
- is_symlink = entry.is_symlink()
- except OSError:
- # If is_symlink() raises an OSError, consider that the
- # entry is not a symbolic link, same behaviour than
- # os.path.islink().
- is_symlink = False
- walk_into = not is_symlink
+ nondirs.append(entry.name)
- if walk_into:
- yield from walk(entry.path, topdown, onerror, followlinks)
+ if not topdown and is_dir:
+ # Bottom-up: recurse into sub-directory, but exclude symlinks to
+ # directories if followlinks is False
+ if followlinks:
+ walk_into = True
+ else:
+ try:
+ is_symlink = entry.is_symlink()
+ except OSError:
+ # If is_symlink() raises an OSError, consider that the
+ # entry is not a symbolic link, same behaviour than
+ # os.path.islink().
+ is_symlink = False
+ walk_into = not is_symlink
+
+ if walk_into:
+ yield from walk(entry.path, topdown, onerror, followlinks)
# Yield before recursion if going top down
if topdown:
@@ -437,15 +438,30 @@ class _DummyDirEntry:
def __init__(self, dir, name):
self.name = name
self.path = path.join(dir, name)
+
def is_dir(self):
return path.isdir(self.path)
+
def is_symlink(self):
return path.islink(self.path)
-def _dummy_scandir(dir):
+class _dummy_scandir:
# listdir-based implementation for bytes patches on Windows
- for name in listdir(dir):
- yield _DummyDirEntry(dir, name)
+ def __init__(self, dir):
+ self.dir = dir
+ self.it = iter(listdir(dir))
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return _DummyDirEntry(self.dir, next(self.it))
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.it = iter(())
__all__.append("walk")
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 66a426a4f1..07682f2bf9 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -2808,6 +2808,8 @@ class ExportsTests(unittest.TestCase):
class TestScandir(unittest.TestCase):
+ check_no_resource_warning = support.check_no_resource_warning
+
def setUp(self):
self.path = os.path.realpath(support.TESTFN)
self.addCleanup(support.rmtree, self.path)
@@ -3030,6 +3032,56 @@ class TestScandir(unittest.TestCase):
for obj in [1234, 1.234, {}, []]:
self.assertRaises(TypeError, os.scandir, obj)
+ def test_close(self):
+ self.create_file("file.txt")
+ self.create_file("file2.txt")
+ iterator = os.scandir(self.path)
+ next(iterator)
+ iterator.close()
+ # multiple closes
+ iterator.close()
+ with self.check_no_resource_warning():
+ del iterator
+
+ def test_context_manager(self):
+ self.create_file("file.txt")
+ self.create_file("file2.txt")
+ with os.scandir(self.path) as iterator:
+ next(iterator)
+ with self.check_no_resource_warning():
+ del iterator
+
+ def test_context_manager_close(self):
+ self.create_file("file.txt")
+ self.create_file("file2.txt")
+ with os.scandir(self.path) as iterator:
+ next(iterator)
+ iterator.close()
+
+ def test_context_manager_exception(self):
+ self.create_file("file.txt")
+ self.create_file("file2.txt")
+ with self.assertRaises(ZeroDivisionError):
+ with os.scandir(self.path) as iterator:
+ next(iterator)
+ 1/0
+ with self.check_no_resource_warning():
+ del iterator
+
+ def test_resource_warning(self):
+ self.create_file("file.txt")
+ self.create_file("file2.txt")
+ iterator = os.scandir(self.path)
+ next(iterator)
+ with self.assertWarns(ResourceWarning):
+ del iterator
+ support.gc_collect()
+ # exhausted iterator
+ iterator = os.scandir(self.path)
+ list(iterator)
+ with self.check_no_resource_warning():
+ del iterator
+
if __name__ == "__main__":
unittest.main()