From 8ffbed9e0d231761a83ffe3547d4e2d28291782e Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 11 Feb 2016 13:21:30 +0200 Subject: Issue #25994: Added the close() method and the support of the context manager protocol for the os.scandir() iterator. --- Lib/os.py | 94 +++++++++++++++++++++++++++++++---------------------- Lib/test/test_os.py | 52 +++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 39 deletions(-) (limited to 'Lib') diff --git a/Lib/os.py b/Lib/os.py index 674a7d7efd..c3f674ec3d 100644 --- a/Lib/os.py +++ b/Lib/os.py @@ -374,46 +374,47 @@ def walk(top, topdown=True, onerror=None, followlinks=False): onerror(error) return - while True: - try: + with scandir_it: + while True: try: - entry = next(scandir_it) - except StopIteration: - break - except OSError as error: - if onerror is not None: - onerror(error) - return - - try: - is_dir = entry.is_dir() - except OSError: - # If is_dir() raises an OSError, consider that the entry is not - # a directory, same behaviour than os.path.isdir(). - is_dir = False - - if is_dir: - dirs.append(entry.name) - else: - nondirs.append(entry.name) + try: + entry = next(scandir_it) + except StopIteration: + break + except OSError as error: + if onerror is not None: + onerror(error) + return - if not topdown and is_dir: - # Bottom-up: recurse into sub-directory, but exclude symlinks to - # directories if followlinks is False - if followlinks: - walk_into = True + try: + is_dir = entry.is_dir() + except OSError: + # If is_dir() raises an OSError, consider that the entry is not + # a directory, same behaviour than os.path.isdir(). + is_dir = False + + if is_dir: + dirs.append(entry.name) else: - try: - is_symlink = entry.is_symlink() - except OSError: - # If is_symlink() raises an OSError, consider that the - # entry is not a symbolic link, same behaviour than - # os.path.islink(). - is_symlink = False - walk_into = not is_symlink + nondirs.append(entry.name) - if walk_into: - yield from walk(entry.path, topdown, onerror, followlinks) + if not topdown and is_dir: + # Bottom-up: recurse into sub-directory, but exclude symlinks to + # directories if followlinks is False + if followlinks: + walk_into = True + else: + try: + is_symlink = entry.is_symlink() + except OSError: + # If is_symlink() raises an OSError, consider that the + # entry is not a symbolic link, same behaviour than + # os.path.islink(). + is_symlink = False + walk_into = not is_symlink + + if walk_into: + yield from walk(entry.path, topdown, onerror, followlinks) # Yield before recursion if going top down if topdown: @@ -437,15 +438,30 @@ class _DummyDirEntry: def __init__(self, dir, name): self.name = name self.path = path.join(dir, name) + def is_dir(self): return path.isdir(self.path) + def is_symlink(self): return path.islink(self.path) -def _dummy_scandir(dir): +class _dummy_scandir: # listdir-based implementation for bytes patches on Windows - for name in listdir(dir): - yield _DummyDirEntry(dir, name) + def __init__(self, dir): + self.dir = dir + self.it = iter(listdir(dir)) + + def __iter__(self): + return self + + def __next__(self): + return _DummyDirEntry(self.dir, next(self.it)) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.it = iter(()) __all__.append("walk") diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 66a426a4f1..07682f2bf9 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -2808,6 +2808,8 @@ class ExportsTests(unittest.TestCase): class TestScandir(unittest.TestCase): + check_no_resource_warning = support.check_no_resource_warning + def setUp(self): self.path = os.path.realpath(support.TESTFN) self.addCleanup(support.rmtree, self.path) @@ -3030,6 +3032,56 @@ class TestScandir(unittest.TestCase): for obj in [1234, 1.234, {}, []]: self.assertRaises(TypeError, os.scandir, obj) + def test_close(self): + self.create_file("file.txt") + self.create_file("file2.txt") + iterator = os.scandir(self.path) + next(iterator) + iterator.close() + # multiple closes + iterator.close() + with self.check_no_resource_warning(): + del iterator + + def test_context_manager(self): + self.create_file("file.txt") + self.create_file("file2.txt") + with os.scandir(self.path) as iterator: + next(iterator) + with self.check_no_resource_warning(): + del iterator + + def test_context_manager_close(self): + self.create_file("file.txt") + self.create_file("file2.txt") + with os.scandir(self.path) as iterator: + next(iterator) + iterator.close() + + def test_context_manager_exception(self): + self.create_file("file.txt") + self.create_file("file2.txt") + with self.assertRaises(ZeroDivisionError): + with os.scandir(self.path) as iterator: + next(iterator) + 1/0 + with self.check_no_resource_warning(): + del iterator + + def test_resource_warning(self): + self.create_file("file.txt") + self.create_file("file2.txt") + iterator = os.scandir(self.path) + next(iterator) + with self.assertWarns(ResourceWarning): + del iterator + support.gc_collect() + # exhausted iterator + iterator = os.scandir(self.path) + list(iterator) + with self.check_no_resource_warning(): + del iterator + if __name__ == "__main__": unittest.main() -- cgit v1.2.1