diff options
Diffstat (limited to 'bzrlib/tests/TestUtil.py')
-rw-r--r-- | bzrlib/tests/TestUtil.py | 233 |
1 files changed, 233 insertions, 0 deletions
diff --git a/bzrlib/tests/TestUtil.py b/bzrlib/tests/TestUtil.py new file mode 100644 index 0000000..f54d9bc --- /dev/null +++ b/bzrlib/tests/TestUtil.py @@ -0,0 +1,233 @@ +# Copyright (C) 2005-2011 Canonical Ltd +# Author: Robert Collins <robert.collins@canonical.com> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +# + +import sys +import logging +import unittest +import weakref + +from bzrlib import pyutils + +# Mark this python module as being part of the implementation +# of unittest: this gives us better tracebacks where the last +# shown frame is the test code, not our assertXYZ. +__unittest = 1 + + +class LogCollector(logging.Handler): + + def __init__(self): + logging.Handler.__init__(self) + self.records=[] + + def emit(self, record): + self.records.append(record.getMessage()) + + +def makeCollectingLogger(): + """I make a logger instance that collects its logs for programmatic analysis + -> (logger, collector)""" + logger=logging.Logger("collector") + handler=LogCollector() + handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + logger.addHandler(handler) + return logger, handler + + +def visitTests(suite, visitor): + """A foreign method for visiting the tests in a test suite.""" + for test in suite._tests: + #Abusing types to avoid monkey patching unittest.TestCase. + # Maybe that would be better? + try: + test.visit(visitor) + except AttributeError: + if isinstance(test, unittest.TestCase): + visitor.visitCase(test) + elif isinstance(test, unittest.TestSuite): + visitor.visitSuite(test) + visitTests(test, visitor) + else: + print "unvisitable non-unittest.TestCase element %r (%r)" % ( + test, test.__class__) + + +class FailedCollectionCase(unittest.TestCase): + """Pseudo-test to run and report failure if given case was uncollected""" + + def __init__(self, case): + super(FailedCollectionCase, self).__init__("fail_uncollected") + # GZ 2011-09-16: Maybe catch errors from id() method as cases may be + # in a bit of a funny state by now. + self._problem_case_id = case.id() + + def id(self): + if self._problem_case_id[-1:] == ")": + return self._problem_case_id[:-1] + ",uncollected)" + return self._problem_case_id + "(uncollected)" + + def fail_uncollected(self): + self.fail("Uncollected test case: " + self._problem_case_id) + + +class TestSuite(unittest.TestSuite): + """I am an extended TestSuite with a visitor interface. + This is primarily to allow filtering of tests - and suites or + more in the future. An iterator of just tests wouldn't scale...""" + + def visit(self, visitor): + """visit the composite. Visiting is depth-first. + current callbacks are visitSuite and visitCase.""" + visitor.visitSuite(self) + visitTests(self, visitor) + + def run(self, result): + """Run the tests in the suite, discarding references after running.""" + tests = list(self) + tests.reverse() + self._tests = [] + stored_count = 0 + count_stored_tests = getattr(result, "_count_stored_tests", int) + from bzrlib.tests import selftest_debug_flags + notify = "uncollected_cases" in selftest_debug_flags + while tests: + if result.shouldStop: + self._tests = reversed(tests) + break + case = _run_and_collect_case(tests.pop(), result)() + new_stored_count = count_stored_tests() + if case is not None and isinstance(case, unittest.TestCase): + if stored_count == new_stored_count and notify: + # Testcase didn't fail, but somehow is still alive + FailedCollectionCase(case).run(result) + # Adding a new failure so need to reupdate the count + new_stored_count = count_stored_tests() + # GZ 2011-09-16: Previously zombied the case at this point by + # clearing the dict as fallback, skip for now. + stored_count = new_stored_count + return result + + +def _run_and_collect_case(case, res): + """Run test case against result and use weakref to drop the refcount""" + case.run(res) + return weakref.ref(case) + + +class TestLoader(unittest.TestLoader): + """Custom TestLoader to extend the stock python one.""" + + suiteClass = TestSuite + # Memoize test names by test class dict + test_func_names = {} + + def loadTestsFromModuleNames(self, names): + """use a custom means to load tests from modules. + + There is an undesirable glitch in the python TestLoader where a + import error is ignore. We think this can be solved by ensuring the + requested name is resolvable, if its not raising the original error. + """ + result = self.suiteClass() + for name in names: + result.addTests(self.loadTestsFromModuleName(name)) + return result + + def loadTestsFromModuleName(self, name): + result = self.suiteClass() + module = pyutils.get_named_object(name) + + result.addTests(self.loadTestsFromModule(module)) + return result + + def loadTestsFromModule(self, module): + """Load tests from a module object. + + This extension of the python test loader looks for an attribute + load_tests in the module object, and if not found falls back to the + regular python loadTestsFromModule. + + If a load_tests attribute is found, it is called and the result is + returned. + + load_tests should be defined like so: + >>> def load_tests(standard_tests, module, loader): + >>> pass + + standard_tests is the tests found by the stock TestLoader in the + module, module and loader are the module and loader instances. + + For instance, to run every test twice, you might do: + >>> def load_tests(standard_tests, module, loader): + >>> result = loader.suiteClass() + >>> for test in iter_suite_tests(standard_tests): + >>> result.addTests([test, test]) + >>> return result + """ + if sys.version_info < (2, 7): + basic_tests = super(TestLoader, self).loadTestsFromModule(module) + else: + # GZ 2010-07-19: Python 2.7 unittest also uses load_tests but with + # a different and incompatible signature + basic_tests = super(TestLoader, self).loadTestsFromModule(module, + use_load_tests=False) + load_tests = getattr(module, "load_tests", None) + if load_tests is not None: + return load_tests(basic_tests, module, self) + else: + return basic_tests + + def getTestCaseNames(self, test_case_class): + test_fn_names = self.test_func_names.get(test_case_class, None) + if test_fn_names is not None: + # We already know them + return test_fn_names + + test_fn_names = unittest.TestLoader.getTestCaseNames(self, + test_case_class) + self.test_func_names[test_case_class] = test_fn_names + return test_fn_names + + +class FilteredByModuleTestLoader(TestLoader): + """A test loader that import only the needed modules.""" + + def __init__(self, needs_module): + """Constructor. + + :param needs_module: a callable taking a module name as a + parameter returing True if the module should be loaded. + """ + TestLoader.__init__(self) + self.needs_module = needs_module + + def loadTestsFromModuleName(self, name): + if self.needs_module(name): + return TestLoader.loadTestsFromModuleName(self, name) + else: + return self.suiteClass() + + +class TestVisitor(object): + """A visitor for Tests""" + + def visitSuite(self, aTestSuite): + pass + + def visitCase(self, aTestCase): + pass |