summaryrefslogtreecommitdiff
path: root/logilab/common/testlib.py
diff options
context:
space:
mode:
Diffstat (limited to 'logilab/common/testlib.py')
-rw-r--r--logilab/common/testlib.py213
1 files changed, 131 insertions, 82 deletions
diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py
index 8348900..f8401c4 100644
--- a/logilab/common/testlib.py
+++ b/logilab/common/testlib.py
@@ -64,6 +64,7 @@ import configparser
from logilab.common.deprecation import class_deprecated, deprecated
import unittest as unittest_legacy
+
if not getattr(unittest_legacy, "__package__", None):
try:
import unittest2 as unittest
@@ -83,22 +84,22 @@ from logilab.common.decorators import cached, classproperty
from logilab.common import textutils
-__all__ = ['unittest_main', 'find_tests', 'nocoverage', 'pause_trace']
+__all__ = ["unittest_main", "find_tests", "nocoverage", "pause_trace"]
-DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest',
- 'func', 'validation')
+DEFAULT_PREFIXES = ("test", "regrtest", "smoketest", "unittest", "func", "validation")
-is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction)
+is_generator = deprecated("[lgc 0.63] use inspect.isgeneratorfunction")(isgeneratorfunction)
# used by unittest to count the number of relevant levels in the traceback
__unittest = 1
-@deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.')
+@deprecated("with_tempdir is deprecated, use tempfile.TemporaryDirectory.")
def with_tempdir(callable: Callable) -> Callable:
"""A decorator ensuring no temporary file left when the function return
Work only for temporary file created with the tempfile module"""
if isgeneratorfunction(callable):
+
def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]:
old_tmpdir = tempfile.gettempdir()
new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
@@ -111,9 +112,11 @@ def with_tempdir(callable: Callable) -> Callable:
rmtree(new_tmpdir, ignore_errors=True)
finally:
tempfile.tempdir = old_tmpdir
+
return proxy
else:
+
@wraps(callable)
def proxy(*args: Any, **kargs: Any) -> Any:
@@ -127,11 +130,14 @@ def with_tempdir(callable: Callable) -> Callable:
rmtree(new_tmpdir, ignore_errors=True)
finally:
tempfile.tempdir = old_tmpdir
+
return proxy
+
def in_tempdir(callable):
"""A decorator moving the enclosed function inside the tempfile.tempfdir
"""
+
@wraps(callable)
def proxy(*args, **kargs):
@@ -141,8 +147,10 @@ def in_tempdir(callable):
return callable(*args, **kargs)
finally:
os.chdir(old_cwd)
+
return proxy
+
def within_tempdir(callable):
"""A decorator run the enclosed function inside a tmpdir removed after execution
"""
@@ -150,10 +158,8 @@ def within_tempdir(callable):
proxy.__name__ = callable.__name__
return proxy
-def find_tests(testdir,
- prefixes=DEFAULT_PREFIXES, suffix=".py",
- excludes=(),
- remove_suffix=True):
+
+def find_tests(testdir, prefixes=DEFAULT_PREFIXES, suffix=".py", excludes=(), remove_suffix=True):
"""
Return a list of all applicable test modules.
"""
@@ -163,7 +169,7 @@ def find_tests(testdir,
for prefix in prefixes:
if name.startswith(prefix):
if remove_suffix and name.endswith(suffix):
- name = name[:-len(suffix)]
+ name = name[: -len(suffix)]
if name not in excludes:
tests.append(name)
tests.sort()
@@ -184,13 +190,12 @@ def start_interactive_mode(result):
testindex = 0
print("Choose a test to debug:")
# order debuggers in the same way than errors were printed
- print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr)
- in enumerate(descrs)]))
+ print("\n".join(["\t%s : %s" % (i, descr) for i, (_, descr) in enumerate(descrs)]))
print("Type 'exit' (or ^D) to quit")
print()
try:
- todebug = input('Enter a test name: ')
- if todebug.strip().lower() == 'exit':
+ todebug = input("Enter a test name: ")
+ if todebug.strip().lower() == "exit":
print()
break
else:
@@ -198,7 +203,7 @@ def start_interactive_mode(result):
testindex = int(todebug)
debugger = debuggers[descrs[testindex][0]]
except (ValueError, IndexError):
- print("ERROR: invalid test number %r" % (todebug, ))
+ print("ERROR: invalid test number %r" % (todebug,))
else:
debugger.start()
except (EOFError, KeyboardInterrupt):
@@ -208,6 +213,7 @@ def start_interactive_mode(result):
# coverage pausing tools #####################################################
+
@contextmanager
def replace_trace(trace: Optional[Callable] = None) -> Iterator:
"""A context manager that temporary replaces the trace function"""
@@ -218,8 +224,7 @@ def replace_trace(trace: Optional[Callable] = None) -> Iterator:
finally:
# specific hack to work around a bug in pycoverage, see
# https://bitbucket.org/ned/coveragepy/issue/123
- if (oldtrace is not None and not callable(oldtrace) and
- hasattr(oldtrace, 'pytrace')):
+ if oldtrace is not None and not callable(oldtrace) and hasattr(oldtrace, "pytrace"):
oldtrace = oldtrace.pytrace
sys.settrace(oldtrace)
@@ -229,7 +234,7 @@ pause_trace = replace_trace
def nocoverage(func: Callable) -> Callable:
"""Function decorator that pauses tracing functions"""
- if hasattr(func, 'uncovered'):
+ if hasattr(func, "uncovered"):
return func
# mypy: "Callable[..., Any]" has no attribute "uncovered"
# dynamic attribute for magic
@@ -238,6 +243,7 @@ def nocoverage(func: Callable) -> Callable:
def not_covered(*args: Any, **kwargs: Any) -> Any:
with pause_trace():
return func(*args, **kwargs)
+
# mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered"
# dynamic attribute for magic
not_covered.uncovered = True # type: ignore
@@ -249,49 +255,56 @@ def nocoverage(func: Callable) -> Callable:
# Add deprecation warnings about new api used by module level fixtures in unittest2
# http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule
-class _DebugResult(object): # simplify import statement among unittest flavors..
+class _DebugResult(object): # simplify import statement among unittest flavors..
"Used by the TestSuite to hold previous class when running in debug."
_previousTestClass = None
_moduleSetUpFailed = False
shouldStop = False
+
# backward compatibility: TestSuite might be imported from lgc.testlib
TestSuite = unittest.TestSuite
+
class keywords(dict):
"""Keyword args (**kwargs) support for generative tests."""
+
class starargs(tuple):
"""Variable arguments (*args) for generative tests."""
+
def __new__(cls, *args):
return tuple.__new__(cls, args)
+
unittest_main = unittest.main
class InnerTestSkipped(SkipTest):
"""raised when a test is skipped"""
+
pass
+
def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]:
args = []
varargs = ()
kwargs: Dict = {}
- flags = 0 # 2 <=> starargs, 4 <=> kwargs
+ flags = 0 # 2 <=> starargs, 4 <=> kwargs
for param in params:
if isinstance(param, starargs):
varargs = param
if flags:
- raise TypeError('found starargs after keywords !')
+ raise TypeError("found starargs after keywords !")
flags |= 2
args += list(varargs)
elif isinstance(param, keywords):
kwargs = param
if flags & 4:
- raise TypeError('got multiple keywords parameters')
+ raise TypeError("got multiple keywords parameters")
flags |= 4
elif flags & 2 or flags & 4:
- raise TypeError('found parameters after kwargs or args')
+ raise TypeError("found parameters after kwargs or args")
else:
args.append(param)
@@ -304,13 +317,14 @@ class InnerTest(tuple):
instance.name = name
return instance
+
class Tags(set):
"""A set of tag able validate an expression"""
def __init__(self, *tags: str, **kwargs: Any) -> None:
- self.inherit = kwargs.pop('inherit', True)
+ self.inherit = kwargs.pop("inherit", True)
if kwargs:
- raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
+ raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
if len(tags) == 1 and not isinstance(tags[0], str):
tags = tags[0]
@@ -328,25 +342,26 @@ class Tags(set):
# mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet";
# mypy: supertype defines the argument type as "AbstractSet[_T]"
# not sure how to fix this one
- def __or__(self, other: 'Tags') -> 'Tags': # type: ignore
+ def __or__(self, other: "Tags") -> "Tags": # type: ignore
return Tags(*super(Tags, self).__or__(other))
# duplicate definition from unittest2 of the _deprecate decorator
def _deprecate(original_func):
def deprecated_func(*args, **kwargs):
- warnings.warn(
- ('Please use %s instead.' % original_func.__name__),
- DeprecationWarning, 2)
+ warnings.warn(("Please use %s instead." % original_func.__name__), DeprecationWarning, 2)
return original_func(*args, **kwargs)
+
return deprecated_func
+
class TestCase(unittest.TestCase):
"""A unittest.TestCase extension with some additional methods."""
+
maxDiff = None
tags = Tags()
- def __init__(self, methodName: str = 'runTest') -> None:
+ def __init__(self, methodName: str = "runTest") -> None:
super(TestCase, self).__init__(methodName)
self.__exc_info = sys.exc_info
self.__testMethodName = self._testMethodName
@@ -355,13 +370,14 @@ class TestCase(unittest.TestCase):
@classproperty
@cached
- def datadir(cls) -> str: # pylint: disable=E0213
+ def datadir(cls) -> str: # pylint: disable=E0213
"""helper attribute holding the standard test's data directory
NOTE: this is a logilab's standard
"""
mod = sys.modules[cls.__module__]
- return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data')
+ return osp.join(osp.dirname(osp.abspath(mod.__file__)), "data")
+
# cache it (use a class method to cache on class since TestCase is
# instantiated for each test run)
@@ -392,11 +408,12 @@ class TestCase(unittest.TestCase):
except (KeyboardInterrupt, SystemExit):
raise
except unittest.SkipTest as e:
- if hasattr(result, 'addSkip'):
+ if hasattr(result, "addSkip"):
result.addSkip(self, str(e))
else:
- warnings.warn("TestResult has no addSkip method, skips not reported",
- RuntimeWarning, 2)
+ warnings.warn(
+ "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2
+ )
result.addSuccess(self)
return False
except:
@@ -423,23 +440,26 @@ class TestCase(unittest.TestCase):
# if result.cvg:
# result.cvg.start()
testMethod = self._get_test_method()
- if (getattr(self.__class__, "__unittest_skip__", False) or
- getattr(testMethod, "__unittest_skip__", False)):
+ if getattr(self.__class__, "__unittest_skip__", False) or getattr(
+ testMethod, "__unittest_skip__", False
+ ):
# If the class or method was skipped.
try:
- skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
- or getattr(testMethod, '__unittest_skip_why__', ''))
- if hasattr(result, 'addSkip'):
+ skip_why = getattr(self.__class__, "__unittest_skip_why__", "") or getattr(
+ testMethod, "__unittest_skip_why__", ""
+ )
+ if hasattr(result, "addSkip"):
result.addSkip(self, skip_why)
else:
- warnings.warn("TestResult has no addSkip method, skips not reported",
- RuntimeWarning, 2)
+ warnings.warn(
+ "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2
+ )
result.addSuccess(self)
finally:
result.stopTest(self)
return
if runcondition and not runcondition(testMethod):
- return # test is skipped
+ return # test is skipped
result.startTest(self)
try:
if not self.quiet_run(result, self.setUp):
@@ -447,11 +467,10 @@ class TestCase(unittest.TestCase):
generative = isgeneratorfunction(testMethod)
# generative tests
if generative:
- self._proceed_generative(result, testMethod,
- runcondition)
+ self._proceed_generative(result, testMethod, runcondition)
else:
status = self._proceed(result, testMethod)
- success = (status == 0)
+ success = status == 0
if not self.quiet_run(result, self.tearDown):
return
if not generative and success:
@@ -461,19 +480,19 @@ class TestCase(unittest.TestCase):
# result.cvg.stop()
result.stopTest(self)
- def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool:
+ def _proceed_generative(
+ self, result: Any, testfunc: Callable, runcondition: Callable = None
+ ) -> bool:
# cancel startTest()'s increment
result.testsRun -= 1
success = True
try:
for params in testfunc():
- if runcondition and not runcondition(testfunc,
- skipgenerator=False):
- if not (isinstance(params, InnerTest)
- and runcondition(params)):
+ if runcondition and not runcondition(testfunc, skipgenerator=False):
+ if not (isinstance(params, InnerTest) and runcondition(params)):
continue
if not isinstance(params, (tuple, list)):
- params = (params, )
+ params = (params,)
func = params[0]
args, kwargs = parse_generative_args(params[1:])
# increment test counter manually
@@ -485,9 +504,9 @@ class TestCase(unittest.TestCase):
else:
success = False
# XXX Don't stop anymore if an error occured
- #if status == 2:
+ # if status == 2:
# result.shouldStop = True
- if result.shouldStop: # either on error or on exitfirst + error
+ if result.shouldStop: # either on error or on exitfirst + error
break
except self.failureException:
result.addFailure(self, self.__exc_info())
@@ -500,7 +519,13 @@ class TestCase(unittest.TestCase):
success = False
return success
- def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int:
+ def _proceed(
+ self,
+ result: Any,
+ testfunc: Callable,
+ args: Union[List[bool], List[int], Tuple[()]] = (),
+ kwargs: Optional[Dict] = None,
+ ) -> int:
"""proceed the actual test
returns 0 on success, 1 on failure, 2 on error
@@ -529,39 +554,40 @@ class TestCase(unittest.TestCase):
def innerSkip(self, msg: str = None) -> NoReturn:
"""mark a generative test as skipped for the <msg> reason"""
- msg = msg or 'test was skipped'
+ msg = msg or "test was skipped"
raise InnerTestSkipped(msg)
- if sys.version_info >= (3,2):
+ if sys.version_info >= (3, 2):
assertItemsEqual = unittest.TestCase.assertCountEqual
else:
assertCountEqual = unittest.TestCase.assertItemsEqual
-TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')(
- TestCase.assertItemsEqual)
+
+TestCase.assertItemsEqual = deprecated("assertItemsEqual is deprecated, use assertCountEqual")(
+ TestCase.assertItemsEqual
+)
import doctest
+
class SkippedSuite(unittest.TestSuite):
def test(self):
"""just there to trigger test execution"""
- self.skipped_test('doctest module has no DocTestSuite class')
+ self.skipped_test("doctest module has no DocTestSuite class")
class DocTestFinder(doctest.DocTestFinder):
-
def __init__(self, *args, **kwargs):
- self.skipped = kwargs.pop('skipped', ())
+ self.skipped = kwargs.pop("skipped", ())
doctest.DocTestFinder.__init__(self, *args, **kwargs)
def _get_test(self, obj, name, module, globs, source_lines):
"""override default _get_test method to be able to skip tests
according to skipped attribute's value
"""
- if getattr(obj, '__name__', '') in self.skipped:
+ if getattr(obj, "__name__", "") in self.skipped:
return None
- return doctest.DocTestFinder._get_test(self, obj, name, module,
- globs, source_lines)
+ return doctest.DocTestFinder._get_test(self, obj, name, module, globs, source_lines)
# mypy error: Invalid metaclass 'class_deprecated'
@@ -571,10 +597,11 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore
I don't know how to make unittest.main consider the DocTestSuite instance
without this hack
"""
- __deprecation_warning__ = 'use stdlib doctest module with unittest API directly'
+
+ __deprecation_warning__ = "use stdlib doctest module with unittest API directly"
skipped = ()
- def __call__(self, result=None, runcondition=None, options=None):\
- # pylint: disable=W0613
+
+ def __call__(self, result=None, runcondition=None, options=None): # pylint: disable=W0613
try:
finder = DocTestFinder(skipped=self.skipped)
suite = doctest.DocTestSuite(self.module, test_finder=finder)
@@ -590,6 +617,7 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore
finally:
builtins.__dict__.clear()
builtins.__dict__.update(old_builtins)
+
run = __call__
def test(self):
@@ -607,21 +635,27 @@ class MockConnection:
def cursor(self):
"""Mock cursor method"""
return self
+
def execute(self, query, args=None):
"""Mock execute method"""
- self.received.append( (query, args) )
+ self.received.append((query, args))
+
def fetchone(self):
"""Mock fetchone method"""
return self.results[0]
+
def fetchall(self):
"""Mock fetchall method"""
return self.results
+
def commit(self):
"""Mock commiy method"""
- self.states.append( ('commit', len(self.received)) )
+ self.states.append(("commit", len(self.received)))
+
def rollback(self):
"""Mock rollback method"""
- self.states.append( ('rollback', len(self.received)) )
+ self.states.append(("rollback", len(self.received)))
+
def close(self):
"""Mock close method"""
pass
@@ -629,7 +663,7 @@ class MockConnection:
# mypy error: Name 'Mock' is not defined
# dynamic class created by this class
-def mock_object(**params: Any) -> 'Mock': # type: ignore
+def mock_object(**params: Any) -> "Mock": # type: ignore
"""creates an object using params to set attributes
>>> option = mock_object(verbose=False, index=range(5))
>>> option.verbose
@@ -637,7 +671,7 @@ def mock_object(**params: Any) -> 'Mock': # type: ignore
>>> option.index
[0, 1, 2, 3, 4]
"""
- return type('Mock', (), params)()
+ return type("Mock", (), params)()
def create_files(paths: List[str], chroot: str) -> None:
@@ -664,7 +698,7 @@ def create_files(paths: List[str], chroot: str) -> None:
path = osp.join(chroot, path)
filename = osp.basename(path)
# path is a directory path
- if filename == '':
+ if filename == "":
dirs.add(path)
# path is a filename path
else:
@@ -674,54 +708,69 @@ def create_files(paths: List[str], chroot: str) -> None:
if not osp.isdir(dirpath):
os.makedirs(dirpath)
for filepath in files:
- open(filepath, 'w').close()
+ open(filepath, "w").close()
-class AttrObject: # XXX cf mock_object
+class AttrObject: # XXX cf mock_object
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
+
def tag(*args: str, **kwargs: Any) -> Callable:
"""descriptor adding tag to a function"""
+
def desc(func: Callable) -> Callable:
- assert not hasattr(func, 'tags')
+ assert not hasattr(func, "tags")
# mypy: "Callable[..., Any]" has no attribute "tags"
# dynamic magic attribute
func.tags = Tags(*args, **kwargs) # type: ignore
return func
+
return desc
+
def require_version(version: str) -> Callable:
""" Compare version of python interpreter to the given one. Skip the test
if older.
"""
+
def check_require_version(f: Callable) -> Callable:
- version_elements = version.split('.')
+ version_elements = version.split(".")
try:
compare = tuple([int(v) for v in version_elements])
except ValueError:
- raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version)
+ raise ValueError("%s is not a correct version : should be X.Y[.Z]." % version)
current = sys.version_info[:3]
if current < compare:
+
def new_f(self, *args, **kwargs):
- self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current])))
+ self.skipTest(
+ "Need at least %s version of python. Current version is %s."
+ % (version, ".".join([str(element) for element in current]))
+ )
+
new_f.__name__ = f.__name__
return new_f
else:
return f
+
return check_require_version
+
def require_module(module: str) -> Callable:
""" Check if the given module is loaded. Skip the test if not.
"""
+
def check_require_module(f: Callable) -> Callable:
try:
__import__(module)
return f
except ImportError:
+
def new_f(self, *args, **kwargs):
- self.skipTest('%s can not be imported.' % module)
+ self.skipTest("%s can not be imported." % module)
+
new_f.__name__ = f.__name__
return new_f
- return check_require_module
+ return check_require_module