diff options
Diffstat (limited to 'logilab/common/testlib.py')
-rw-r--r-- | logilab/common/testlib.py | 213 |
1 files changed, 131 insertions, 82 deletions
diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py index 8348900..f8401c4 100644 --- a/logilab/common/testlib.py +++ b/logilab/common/testlib.py @@ -64,6 +64,7 @@ import configparser from logilab.common.deprecation import class_deprecated, deprecated import unittest as unittest_legacy + if not getattr(unittest_legacy, "__package__", None): try: import unittest2 as unittest @@ -83,22 +84,22 @@ from logilab.common.decorators import cached, classproperty from logilab.common import textutils -__all__ = ['unittest_main', 'find_tests', 'nocoverage', 'pause_trace'] +__all__ = ["unittest_main", "find_tests", "nocoverage", "pause_trace"] -DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest', - 'func', 'validation') +DEFAULT_PREFIXES = ("test", "regrtest", "smoketest", "unittest", "func", "validation") -is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction) +is_generator = deprecated("[lgc 0.63] use inspect.isgeneratorfunction")(isgeneratorfunction) # used by unittest to count the number of relevant levels in the traceback __unittest = 1 -@deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.') +@deprecated("with_tempdir is deprecated, use tempfile.TemporaryDirectory.") def with_tempdir(callable: Callable) -> Callable: """A decorator ensuring no temporary file left when the function return Work only for temporary file created with the tempfile module""" if isgeneratorfunction(callable): + def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]: old_tmpdir = tempfile.gettempdir() new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") @@ -111,9 +112,11 @@ def with_tempdir(callable: Callable) -> Callable: rmtree(new_tmpdir, ignore_errors=True) finally: tempfile.tempdir = old_tmpdir + return proxy else: + @wraps(callable) def proxy(*args: Any, **kargs: Any) -> Any: @@ -127,11 +130,14 @@ def with_tempdir(callable: Callable) -> Callable: rmtree(new_tmpdir, ignore_errors=True) finally: tempfile.tempdir = old_tmpdir + return proxy + def in_tempdir(callable): """A decorator moving the enclosed function inside the tempfile.tempfdir """ + @wraps(callable) def proxy(*args, **kargs): @@ -141,8 +147,10 @@ def in_tempdir(callable): return callable(*args, **kargs) finally: os.chdir(old_cwd) + return proxy + def within_tempdir(callable): """A decorator run the enclosed function inside a tmpdir removed after execution """ @@ -150,10 +158,8 @@ def within_tempdir(callable): proxy.__name__ = callable.__name__ return proxy -def find_tests(testdir, - prefixes=DEFAULT_PREFIXES, suffix=".py", - excludes=(), - remove_suffix=True): + +def find_tests(testdir, prefixes=DEFAULT_PREFIXES, suffix=".py", excludes=(), remove_suffix=True): """ Return a list of all applicable test modules. """ @@ -163,7 +169,7 @@ def find_tests(testdir, for prefix in prefixes: if name.startswith(prefix): if remove_suffix and name.endswith(suffix): - name = name[:-len(suffix)] + name = name[: -len(suffix)] if name not in excludes: tests.append(name) tests.sort() @@ -184,13 +190,12 @@ def start_interactive_mode(result): testindex = 0 print("Choose a test to debug:") # order debuggers in the same way than errors were printed - print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr) - in enumerate(descrs)])) + print("\n".join(["\t%s : %s" % (i, descr) for i, (_, descr) in enumerate(descrs)])) print("Type 'exit' (or ^D) to quit") print() try: - todebug = input('Enter a test name: ') - if todebug.strip().lower() == 'exit': + todebug = input("Enter a test name: ") + if todebug.strip().lower() == "exit": print() break else: @@ -198,7 +203,7 @@ def start_interactive_mode(result): testindex = int(todebug) debugger = debuggers[descrs[testindex][0]] except (ValueError, IndexError): - print("ERROR: invalid test number %r" % (todebug, )) + print("ERROR: invalid test number %r" % (todebug,)) else: debugger.start() except (EOFError, KeyboardInterrupt): @@ -208,6 +213,7 @@ def start_interactive_mode(result): # coverage pausing tools ##################################################### + @contextmanager def replace_trace(trace: Optional[Callable] = None) -> Iterator: """A context manager that temporary replaces the trace function""" @@ -218,8 +224,7 @@ def replace_trace(trace: Optional[Callable] = None) -> Iterator: finally: # specific hack to work around a bug in pycoverage, see # https://bitbucket.org/ned/coveragepy/issue/123 - if (oldtrace is not None and not callable(oldtrace) and - hasattr(oldtrace, 'pytrace')): + if oldtrace is not None and not callable(oldtrace) and hasattr(oldtrace, "pytrace"): oldtrace = oldtrace.pytrace sys.settrace(oldtrace) @@ -229,7 +234,7 @@ pause_trace = replace_trace def nocoverage(func: Callable) -> Callable: """Function decorator that pauses tracing functions""" - if hasattr(func, 'uncovered'): + if hasattr(func, "uncovered"): return func # mypy: "Callable[..., Any]" has no attribute "uncovered" # dynamic attribute for magic @@ -238,6 +243,7 @@ def nocoverage(func: Callable) -> Callable: def not_covered(*args: Any, **kwargs: Any) -> Any: with pause_trace(): return func(*args, **kwargs) + # mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered" # dynamic attribute for magic not_covered.uncovered = True # type: ignore @@ -249,49 +255,56 @@ def nocoverage(func: Callable) -> Callable: # Add deprecation warnings about new api used by module level fixtures in unittest2 # http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule -class _DebugResult(object): # simplify import statement among unittest flavors.. +class _DebugResult(object): # simplify import statement among unittest flavors.. "Used by the TestSuite to hold previous class when running in debug." _previousTestClass = None _moduleSetUpFailed = False shouldStop = False + # backward compatibility: TestSuite might be imported from lgc.testlib TestSuite = unittest.TestSuite + class keywords(dict): """Keyword args (**kwargs) support for generative tests.""" + class starargs(tuple): """Variable arguments (*args) for generative tests.""" + def __new__(cls, *args): return tuple.__new__(cls, args) + unittest_main = unittest.main class InnerTestSkipped(SkipTest): """raised when a test is skipped""" + pass + def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]: args = [] varargs = () kwargs: Dict = {} - flags = 0 # 2 <=> starargs, 4 <=> kwargs + flags = 0 # 2 <=> starargs, 4 <=> kwargs for param in params: if isinstance(param, starargs): varargs = param if flags: - raise TypeError('found starargs after keywords !') + raise TypeError("found starargs after keywords !") flags |= 2 args += list(varargs) elif isinstance(param, keywords): kwargs = param if flags & 4: - raise TypeError('got multiple keywords parameters') + raise TypeError("got multiple keywords parameters") flags |= 4 elif flags & 2 or flags & 4: - raise TypeError('found parameters after kwargs or args') + raise TypeError("found parameters after kwargs or args") else: args.append(param) @@ -304,13 +317,14 @@ class InnerTest(tuple): instance.name = name return instance + class Tags(set): """A set of tag able validate an expression""" def __init__(self, *tags: str, **kwargs: Any) -> None: - self.inherit = kwargs.pop('inherit', True) + self.inherit = kwargs.pop("inherit", True) if kwargs: - raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) + raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) if len(tags) == 1 and not isinstance(tags[0], str): tags = tags[0] @@ -328,25 +342,26 @@ class Tags(set): # mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet"; # mypy: supertype defines the argument type as "AbstractSet[_T]" # not sure how to fix this one - def __or__(self, other: 'Tags') -> 'Tags': # type: ignore + def __or__(self, other: "Tags") -> "Tags": # type: ignore return Tags(*super(Tags, self).__or__(other)) # duplicate definition from unittest2 of the _deprecate decorator def _deprecate(original_func): def deprecated_func(*args, **kwargs): - warnings.warn( - ('Please use %s instead.' % original_func.__name__), - DeprecationWarning, 2) + warnings.warn(("Please use %s instead." % original_func.__name__), DeprecationWarning, 2) return original_func(*args, **kwargs) + return deprecated_func + class TestCase(unittest.TestCase): """A unittest.TestCase extension with some additional methods.""" + maxDiff = None tags = Tags() - def __init__(self, methodName: str = 'runTest') -> None: + def __init__(self, methodName: str = "runTest") -> None: super(TestCase, self).__init__(methodName) self.__exc_info = sys.exc_info self.__testMethodName = self._testMethodName @@ -355,13 +370,14 @@ class TestCase(unittest.TestCase): @classproperty @cached - def datadir(cls) -> str: # pylint: disable=E0213 + def datadir(cls) -> str: # pylint: disable=E0213 """helper attribute holding the standard test's data directory NOTE: this is a logilab's standard """ mod = sys.modules[cls.__module__] - return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data') + return osp.join(osp.dirname(osp.abspath(mod.__file__)), "data") + # cache it (use a class method to cache on class since TestCase is # instantiated for each test run) @@ -392,11 +408,12 @@ class TestCase(unittest.TestCase): except (KeyboardInterrupt, SystemExit): raise except unittest.SkipTest as e: - if hasattr(result, 'addSkip'): + if hasattr(result, "addSkip"): result.addSkip(self, str(e)) else: - warnings.warn("TestResult has no addSkip method, skips not reported", - RuntimeWarning, 2) + warnings.warn( + "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2 + ) result.addSuccess(self) return False except: @@ -423,23 +440,26 @@ class TestCase(unittest.TestCase): # if result.cvg: # result.cvg.start() testMethod = self._get_test_method() - if (getattr(self.__class__, "__unittest_skip__", False) or - getattr(testMethod, "__unittest_skip__", False)): + if getattr(self.__class__, "__unittest_skip__", False) or getattr( + testMethod, "__unittest_skip__", False + ): # If the class or method was skipped. try: - skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') - or getattr(testMethod, '__unittest_skip_why__', '')) - if hasattr(result, 'addSkip'): + skip_why = getattr(self.__class__, "__unittest_skip_why__", "") or getattr( + testMethod, "__unittest_skip_why__", "" + ) + if hasattr(result, "addSkip"): result.addSkip(self, skip_why) else: - warnings.warn("TestResult has no addSkip method, skips not reported", - RuntimeWarning, 2) + warnings.warn( + "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2 + ) result.addSuccess(self) finally: result.stopTest(self) return if runcondition and not runcondition(testMethod): - return # test is skipped + return # test is skipped result.startTest(self) try: if not self.quiet_run(result, self.setUp): @@ -447,11 +467,10 @@ class TestCase(unittest.TestCase): generative = isgeneratorfunction(testMethod) # generative tests if generative: - self._proceed_generative(result, testMethod, - runcondition) + self._proceed_generative(result, testMethod, runcondition) else: status = self._proceed(result, testMethod) - success = (status == 0) + success = status == 0 if not self.quiet_run(result, self.tearDown): return if not generative and success: @@ -461,19 +480,19 @@ class TestCase(unittest.TestCase): # result.cvg.stop() result.stopTest(self) - def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool: + def _proceed_generative( + self, result: Any, testfunc: Callable, runcondition: Callable = None + ) -> bool: # cancel startTest()'s increment result.testsRun -= 1 success = True try: for params in testfunc(): - if runcondition and not runcondition(testfunc, - skipgenerator=False): - if not (isinstance(params, InnerTest) - and runcondition(params)): + if runcondition and not runcondition(testfunc, skipgenerator=False): + if not (isinstance(params, InnerTest) and runcondition(params)): continue if not isinstance(params, (tuple, list)): - params = (params, ) + params = (params,) func = params[0] args, kwargs = parse_generative_args(params[1:]) # increment test counter manually @@ -485,9 +504,9 @@ class TestCase(unittest.TestCase): else: success = False # XXX Don't stop anymore if an error occured - #if status == 2: + # if status == 2: # result.shouldStop = True - if result.shouldStop: # either on error or on exitfirst + error + if result.shouldStop: # either on error or on exitfirst + error break except self.failureException: result.addFailure(self, self.__exc_info()) @@ -500,7 +519,13 @@ class TestCase(unittest.TestCase): success = False return success - def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int: + def _proceed( + self, + result: Any, + testfunc: Callable, + args: Union[List[bool], List[int], Tuple[()]] = (), + kwargs: Optional[Dict] = None, + ) -> int: """proceed the actual test returns 0 on success, 1 on failure, 2 on error @@ -529,39 +554,40 @@ class TestCase(unittest.TestCase): def innerSkip(self, msg: str = None) -> NoReturn: """mark a generative test as skipped for the <msg> reason""" - msg = msg or 'test was skipped' + msg = msg or "test was skipped" raise InnerTestSkipped(msg) - if sys.version_info >= (3,2): + if sys.version_info >= (3, 2): assertItemsEqual = unittest.TestCase.assertCountEqual else: assertCountEqual = unittest.TestCase.assertItemsEqual -TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')( - TestCase.assertItemsEqual) + +TestCase.assertItemsEqual = deprecated("assertItemsEqual is deprecated, use assertCountEqual")( + TestCase.assertItemsEqual +) import doctest + class SkippedSuite(unittest.TestSuite): def test(self): """just there to trigger test execution""" - self.skipped_test('doctest module has no DocTestSuite class') + self.skipped_test("doctest module has no DocTestSuite class") class DocTestFinder(doctest.DocTestFinder): - def __init__(self, *args, **kwargs): - self.skipped = kwargs.pop('skipped', ()) + self.skipped = kwargs.pop("skipped", ()) doctest.DocTestFinder.__init__(self, *args, **kwargs) def _get_test(self, obj, name, module, globs, source_lines): """override default _get_test method to be able to skip tests according to skipped attribute's value """ - if getattr(obj, '__name__', '') in self.skipped: + if getattr(obj, "__name__", "") in self.skipped: return None - return doctest.DocTestFinder._get_test(self, obj, name, module, - globs, source_lines) + return doctest.DocTestFinder._get_test(self, obj, name, module, globs, source_lines) # mypy error: Invalid metaclass 'class_deprecated' @@ -571,10 +597,11 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore I don't know how to make unittest.main consider the DocTestSuite instance without this hack """ - __deprecation_warning__ = 'use stdlib doctest module with unittest API directly' + + __deprecation_warning__ = "use stdlib doctest module with unittest API directly" skipped = () - def __call__(self, result=None, runcondition=None, options=None):\ - # pylint: disable=W0613 + + def __call__(self, result=None, runcondition=None, options=None): # pylint: disable=W0613 try: finder = DocTestFinder(skipped=self.skipped) suite = doctest.DocTestSuite(self.module, test_finder=finder) @@ -590,6 +617,7 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore finally: builtins.__dict__.clear() builtins.__dict__.update(old_builtins) + run = __call__ def test(self): @@ -607,21 +635,27 @@ class MockConnection: def cursor(self): """Mock cursor method""" return self + def execute(self, query, args=None): """Mock execute method""" - self.received.append( (query, args) ) + self.received.append((query, args)) + def fetchone(self): """Mock fetchone method""" return self.results[0] + def fetchall(self): """Mock fetchall method""" return self.results + def commit(self): """Mock commiy method""" - self.states.append( ('commit', len(self.received)) ) + self.states.append(("commit", len(self.received))) + def rollback(self): """Mock rollback method""" - self.states.append( ('rollback', len(self.received)) ) + self.states.append(("rollback", len(self.received))) + def close(self): """Mock close method""" pass @@ -629,7 +663,7 @@ class MockConnection: # mypy error: Name 'Mock' is not defined # dynamic class created by this class -def mock_object(**params: Any) -> 'Mock': # type: ignore +def mock_object(**params: Any) -> "Mock": # type: ignore """creates an object using params to set attributes >>> option = mock_object(verbose=False, index=range(5)) >>> option.verbose @@ -637,7 +671,7 @@ def mock_object(**params: Any) -> 'Mock': # type: ignore >>> option.index [0, 1, 2, 3, 4] """ - return type('Mock', (), params)() + return type("Mock", (), params)() def create_files(paths: List[str], chroot: str) -> None: @@ -664,7 +698,7 @@ def create_files(paths: List[str], chroot: str) -> None: path = osp.join(chroot, path) filename = osp.basename(path) # path is a directory path - if filename == '': + if filename == "": dirs.add(path) # path is a filename path else: @@ -674,54 +708,69 @@ def create_files(paths: List[str], chroot: str) -> None: if not osp.isdir(dirpath): os.makedirs(dirpath) for filepath in files: - open(filepath, 'w').close() + open(filepath, "w").close() -class AttrObject: # XXX cf mock_object +class AttrObject: # XXX cf mock_object def __init__(self, **kwargs): self.__dict__.update(kwargs) + def tag(*args: str, **kwargs: Any) -> Callable: """descriptor adding tag to a function""" + def desc(func: Callable) -> Callable: - assert not hasattr(func, 'tags') + assert not hasattr(func, "tags") # mypy: "Callable[..., Any]" has no attribute "tags" # dynamic magic attribute func.tags = Tags(*args, **kwargs) # type: ignore return func + return desc + def require_version(version: str) -> Callable: """ Compare version of python interpreter to the given one. Skip the test if older. """ + def check_require_version(f: Callable) -> Callable: - version_elements = version.split('.') + version_elements = version.split(".") try: compare = tuple([int(v) for v in version_elements]) except ValueError: - raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version) + raise ValueError("%s is not a correct version : should be X.Y[.Z]." % version) current = sys.version_info[:3] if current < compare: + def new_f(self, *args, **kwargs): - self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current]))) + self.skipTest( + "Need at least %s version of python. Current version is %s." + % (version, ".".join([str(element) for element in current])) + ) + new_f.__name__ = f.__name__ return new_f else: return f + return check_require_version + def require_module(module: str) -> Callable: """ Check if the given module is loaded. Skip the test if not. """ + def check_require_module(f: Callable) -> Callable: try: __import__(module) return f except ImportError: + def new_f(self, *args, **kwargs): - self.skipTest('%s can not be imported.' % module) + self.skipTest("%s can not be imported." % module) + new_f.__name__ = f.__name__ return new_f - return check_require_module + return check_require_module |