summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--distutils/filelist.py29
-rw-r--r--distutils/tests/test_filelist.py10
2 files changed, 38 insertions, 1 deletions
diff --git a/distutils/filelist.py b/distutils/filelist.py
index 1d5e4c87..82a77384 100644
--- a/distutils/filelist.py
+++ b/distutils/filelist.py
@@ -247,14 +247,41 @@ def _find_all_simple(path):
"""
Find all files under 'path'
"""
+ all_unique = _UniqueDirs.filter(os.walk(path, followlinks=True))
results = (
os.path.join(base, file)
- for base, dirs, files in os.walk(path, followlinks=True)
+ for base, dirs, files in all_unique
for file in files
)
return filter(os.path.isfile, results)
+class _UniqueDirs(set):
+ """
+ Exclude previously-seen dirs from walk results,
+ avoiding infinite recursion.
+ Ref https://bugs.python.org/issue44497.
+ """
+ def __call__(self, walk_item):
+ """
+ Given an item from an os.walk result, determine
+ if the item represents a unique dir for this instance
+ and if not, prevent further traversal.
+ """
+ base, dirs, files = walk_item
+ stat = os.stat(base)
+ candidate = stat.st_dev, stat.st_ino
+ found = candidate in self
+ if found:
+ del dirs[:]
+ self.add(candidate)
+ return not found
+
+ @classmethod
+ def filter(cls, items):
+ return filter(cls(), items)
+
+
def findall(dir=os.curdir):
"""
Find all files under 'dir' and return the list of full filenames.
diff --git a/distutils/tests/test_filelist.py b/distutils/tests/test_filelist.py
index d8e4b39f..9ec507b5 100644
--- a/distutils/tests/test_filelist.py
+++ b/distutils/tests/test_filelist.py
@@ -331,6 +331,16 @@ class FindAllTestCase(unittest.TestCase):
expected = [file1]
self.assertEqual(filelist.findall(temp_dir), expected)
+ @os_helper.skip_unless_symlink
+ def test_symlink_loop(self):
+ with os_helper.temp_dir() as temp_dir:
+ link = os.path.join(temp_dir, 'link-to-parent')
+ content = os.path.join(temp_dir, 'somefile')
+ os_helper.create_empty_file(content)
+ os.symlink('.', link)
+ files = filelist.findall(temp_dir)
+ assert len(files) == 1
+
def test_suite():
return unittest.TestSuite([