summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Doc/library/unittest.rst6
-rw-r--r--Lib/unittest/__init__.py9
-rw-r--r--Lib/unittest/loader.py162
-rw-r--r--Lib/unittest/test/test_discovery.py45
-rw-r--r--Lib/unittest/test/test_loader.py2
-rw-r--r--Misc/NEWS2
6 files changed, 167 insertions, 59 deletions
diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
index 355e31f94a..341c7acb6b 100644
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -1668,7 +1668,11 @@ Loading and running tests
If a package (a directory containing a file named :file:`__init__.py`) is
found, the package will be checked for a ``load_tests`` function. If this
- exists then it will be called with *loader*, *tests*, *pattern*.
+ exists then it will be called
+ ``package.load_tests(loader, tests, pattern)``. Test discovery takes care
+ to ensure that a package is only checked for tests once during an
+ invocation, even if the load_tests function itself calls
+ ``loader.discover``.
If ``load_tests`` exists then discovery does *not* recurse into the
package, ``load_tests`` is responsible for loading all tests in the
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index a5d50af78f..f6d7ae278b 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -67,3 +67,12 @@ from .signals import installHandler, registerResult, removeResult, removeHandler
# deprecated
_TextTestResult = TextTestResult
+
+# There are no tests here, so don't try to run anything discovered from
+# introspecting the symbols (e.g. FunctionTestCase). Instead, all our
+# tests come from within unittest.test.
+def load_tests(loader, tests, pattern):
+ import os.path
+ # top level directory cached on loader instance
+ this_dir = os.path.dirname(__file__)
+ return loader.discover(start_dir=this_dir, pattern=pattern)
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index 811bedf928..8c10ad1f41 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -65,6 +65,9 @@ class TestLoader(object):
def __init__(self):
super(TestLoader, self).__init__()
self.errors = []
+ # Tracks packages which we have called into via load_tests, to
+ # avoid infinite re-entrancy.
+ self._loading_packages = set()
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@ class TestLoader(object):
If a test package name (directory with '__init__.py') matches the
pattern then the package will be checked for a 'load_tests' function. If
- this exists then it will be called with loader, tests, pattern.
+ this exists then it will be called with (loader, tests, pattern) unless
+ the package has already had load_tests called from the same discovery
+ invocation, in which case the package module object is not scanned for
+ tests - this ensures that when a package uses discover to further
+ discover child tests that infinite recursion does not happen.
- If load_tests exists then discovery does *not* recurse into the package,
+ If load_tests exists then discovery does *not* recurse into the package,
load_tests is responsible for loading all tests in the package.
The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@ class TestLoader(object):
def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
+ # Handle the __init__ in this package
+ name = self._get_name_from_path(start_dir)
+ # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+ # definition not a package).
+ if name != '.' and name not in self._loading_packages:
+ # name is in self._loading_packages while we have called into
+ # loadTestsFromModule with name.
+ tests, should_recurse = self._find_test_path(
+ start_dir, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if not should_recurse:
+ # Either an error occured, or load_tests was used by the
+ # package.
+ return
+ # Handle the contents.
paths = sorted(os.listdir(start_dir))
-
for path in paths:
full_path = os.path.join(start_dir, path)
- if os.path.isfile(full_path):
- if not VALID_MODULE_NAME.match(path):
- # valid Python identifiers only
- continue
- if not self._match_path(path, full_path, pattern):
- continue
- # if the test file matches, load it
+ tests, should_recurse = self._find_test_path(
+ full_path, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if should_recurse:
+ # we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
+ self._loading_packages.add(name)
try:
- module = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- error_case, error_message = \
- _make_failed_import_test(name, self.suiteClass)
- self.errors.append(error_message)
- yield error_case
- else:
- mod_file = os.path.abspath(getattr(module, '__file__', full_path))
- realpath = _jython_aware_splitext(os.path.realpath(mod_file))
- fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
- if realpath.lower() != fullpath_noext.lower():
- module_dir = os.path.dirname(realpath)
- mod_name = _jython_aware_splitext(os.path.basename(full_path))
- expected_dir = os.path.dirname(full_path)
- msg = ("%r module incorrectly imported from %r. Expected %r. "
- "Is this module globally installed?")
- raise ImportError(msg % (mod_name, module_dir, expected_dir))
- yield self.loadTestsFromModule(module, pattern=pattern)
- elif os.path.isdir(full_path):
- if (not namespace and
- not os.path.isfile(os.path.join(full_path, '__init__.py'))):
- continue
-
- load_tests = None
- tests = None
- name = self._get_name_from_path(full_path)
+ yield from self._find_tests(full_path, pattern, namespace)
+ finally:
+ self._loading_packages.discard(name)
+
+ def _find_test_path(self, full_path, pattern, namespace=False):
+ """Used by discovery.
+
+ Loads tests from a single file, or a directories' __init__.py when
+ passed the directory.
+
+ Returns a tuple (None_or_tests_from_file, should_recurse).
+ """
+ basename = os.path.basename(full_path)
+ if os.path.isfile(full_path):
+ if not VALID_MODULE_NAME.match(basename):
+ # valid Python identifiers only
+ return None, False
+ if not self._match_path(basename, full_path, pattern):
+ return None, False
+ # if the test file matches, load it
+ name = self._get_name_from_path(full_path)
+ try:
+ module = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ mod_file = os.path.abspath(
+ getattr(module, '__file__', full_path))
+ realpath = _jython_aware_splitext(
+ os.path.realpath(mod_file))
+ fullpath_noext = _jython_aware_splitext(
+ os.path.realpath(full_path))
+ if realpath.lower() != fullpath_noext.lower():
+ module_dir = os.path.dirname(realpath)
+ mod_name = _jython_aware_splitext(
+ os.path.basename(full_path))
+ expected_dir = os.path.dirname(full_path)
+ msg = ("%r module incorrectly imported from %r. Expected "
+ "%r. Is this module globally installed?")
+ raise ImportError(
+ msg % (mod_name, module_dir, expected_dir))
+ return self.loadTestsFromModule(module, pattern=pattern), False
+ elif os.path.isdir(full_path):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+ return None, False
+
+ load_tests = None
+ tests = None
+ name = self._get_name_from_path(full_path)
+ try:
+ package = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ load_tests = getattr(package, 'load_tests', None)
+ # Mark this package as being in load_tests (possibly ;))
+ self._loading_packages.add(name)
try:
- package = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- error_case, error_message = \
- _make_failed_import_test(name, self.suiteClass)
- self.errors.append(error_message)
- yield error_case
- else:
- load_tests = getattr(package, 'load_tests', None)
tests = self.loadTestsFromModule(package, pattern=pattern)
- if tests is not None:
- # tests loaded from package file
- yield tests
-
if load_tests is not None:
- # loadTestsFromModule(package) has load_tests for us.
- continue
- # recurse into the package
- yield from self._find_tests(full_path, pattern,
- namespace=namespace)
+ # loadTestsFromModule(package) has loaded tests for us.
+ return tests, False
+ return tests, True
+ finally:
+ self._loading_packages.discard(name)
defaultTestLoader = TestLoader()
diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py
index 92b983a527..4f61314ec6 100644
--- a/Lib/unittest/test/test_discovery.py
+++ b/Lib/unittest/test/test_discovery.py
@@ -368,6 +368,51 @@ class TestDiscovery(unittest.TestCase):
self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])
self.assertIn(top_level_dir, sys.path)
+ def test_discover_start_dir_is_package_calls_package_load_tests(self):
+ # This test verifies that the package load_tests in a package is indeed
+ # invoked when the start_dir is a package (and not the top level).
+ # http://bugs.python.org/issue22457
+
+ # Test data: we expect the following:
+ # an isfile to verify the package, then importing and scanning
+ # as per _find_tests' normal behaviour.
+ # We expect to see our load_tests hook called once.
+ vfs = {abspath('/toplevel'): ['startdir'],
+ abspath('/toplevel/startdir'): ['__init__.py']}
+ def list_dir(path):
+ return list(vfs[path])
+ self.addCleanup(setattr, os, 'listdir', os.listdir)
+ os.listdir = list_dir
+ self.addCleanup(setattr, os.path, 'isfile', os.path.isfile)
+ os.path.isfile = lambda path: path.endswith('.py')
+ self.addCleanup(setattr, os.path, 'isdir', os.path.isdir)
+ os.path.isdir = lambda path: not path.endswith('.py')
+ self.addCleanup(sys.path.remove, abspath('/toplevel'))
+
+ class Module(object):
+ paths = []
+ load_tests_args = []
+
+ def __init__(self, path):
+ self.path = path
+
+ def load_tests(self, loader, tests, pattern):
+ return ['load_tests called ' + self.path]
+
+ def __eq__(self, other):
+ return self.path == other.path
+
+ loader = unittest.TestLoader()
+ loader._get_module_from_name = lambda name: Module(name)
+ loader.suiteClass = lambda thing: thing
+
+ suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel')
+
+ # We should have loaded tests from the package __init__.
+ # (normally this would be nested TestSuites.)
+ self.assertEqual(suite,
+ [['load_tests called startdir']])
+
def setup_import_issue_tests(self, fakefile):
listdir = os.listdir
os.listdir = lambda _: [fakefile]
diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py
index c489730232..68f1036111 100644
--- a/Lib/unittest/test/test_loader.py
+++ b/Lib/unittest/test/test_loader.py
@@ -841,7 +841,7 @@ class Test_TestLoader(unittest.TestCase):
loader = unittest.TestLoader()
suite = loader.loadTestsFromNames(
- ['unittest.loader.sdasfasfasdf', 'unittest'])
+ ['unittest.loader.sdasfasfasdf', 'unittest.test.dummy'])
error, test = self.check_deferred_error(loader, list(suite)[0])
expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'"
self.assertIn(
diff --git a/Misc/NEWS b/Misc/NEWS
index 27e8c8e4f6..f69aa8bd05 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -212,6 +212,8 @@ Library
- Issue #22217: Implemented reprs of classes in the zipfile module.
+- Issue #22457: Honour load_tests in the start_dir of discovery.
+
- Issue #18216: gettext now raises an error when a .mo file has an
unsupported major version number. Patch by Aaron Hill.