diff options
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 164 |
1 files changed, 123 insertions, 41 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 3e0630890..c39f9f32e 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -7,6 +7,7 @@ except ImportError: import argparse import collections +from functools import update_wrapper import inspect import itertools import operator @@ -16,6 +17,13 @@ import sys import pytest +try: + import typing +except ImportError: + pass +else: + if typing.TYPE_CHECKING: + from typing import Sequence try: import xdist # noqa @@ -295,6 +303,49 @@ def getargspec(fn): return inspect.getargspec(fn) +def _pytest_fn_decorator(target): + """Port of langhelpers.decorator with pytest-specific tricks.""" + + from sqlalchemy.util.langhelpers import format_argspec_plus + from sqlalchemy.util.compat import inspect_getfullargspec + + def _exec_code_in_env(code, env, fn_name): + exec(code, env) + return env[fn_name] + + def decorate(fn, add_positional_parameters=()): + + spec = inspect_getfullargspec(fn) + if add_positional_parameters: + spec.args.extend(add_positional_parameters) + + metadata = dict(target="target", fn="fn", name=fn.__name__) + metadata.update(format_argspec_plus(spec, grouped=False)) + code = ( + """\ +def %(name)s(%(args)s): + return %(target)s(%(fn)s, %(apply_kw)s) +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {"target": target, "fn": fn}, fn.__name__ + ) + if not add_positional_parameters: + decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + else: + # this is the pytest hacky part. don't do a full update wrapper + # because pytest is really being sneaky about finding the args + # for the wrapped function + decorated.__module__ = fn.__module__ + decorated.__name__ = fn.__name__ + return decorated + + return decorate + + class PytestFixtureFunctions(plugin_base.FixtureFunctions): def skip_test_exception(self, *arg, **kw): return pytest.skip.Exception(*arg, **kw) @@ -326,8 +377,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): argnames = kw.pop("argnames", None) - exclusion_combinations = [] - def _filter_exclusions(args): result = [] gathered_exclusions = [] @@ -337,13 +386,12 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): else: result.append(a) - exclusion_combinations.extend( - [(exclusion, result) for exclusion in gathered_exclusions] - ) - return result + return result, gathered_exclusions id_ = kw.pop("id_", None) + tobuild_pytest_params = [] + has_exclusions = False if id_: _combination_id_fns = self._combination_id_fns @@ -364,53 +412,87 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): if char in _combination_id_fns ] - arg_sets = [ - pytest.param( - *_arg_getter(_filter_exclusions(arg))[1:], - id="-".join( - comb_fn(getter(arg)) for getter, comb_fn in fns + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + parameters = _arg_getter(fn_params)[1:] + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + ( + parameters, + param_exclusions, + "-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ), ) ) - for arg in [ - (arg,) if not isinstance(arg, tuple) else arg - for arg in arg_sets - ] - ] + else: - # ensure using pytest.param so that even a 1-arg paramset - # still needs to be a tuple. otherwise paramtrize tries to - # interpret a single arg differently than tuple arg - arg_sets = [ - pytest.param(*_filter_exclusions(arg)) - for arg in [ - (arg,) if not isinstance(arg, tuple) else arg - for arg in arg_sets - ] - ] + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + (fn_params, param_exclusions, None) + ) + + pytest_params = [] + for parameters, param_exclusions, id_ in tobuild_pytest_params: + if has_exclusions: + parameters += (param_exclusions,) + + param = pytest.param(*parameters, id=id_) + pytest_params.append(param) def decorate(fn): if inspect.isclass(fn): + if has_exclusions: + raise NotImplementedError( + "exclusions not supported for class level combinations" + ) if "_sa_parametrize" not in fn.__dict__: fn._sa_parametrize = [] - fn._sa_parametrize.append((argnames, arg_sets)) + fn._sa_parametrize.append((argnames, pytest_params)) return fn else: if argnames is None: - _argnames = getargspec(fn).args[1:] + _argnames = getargspec(fn).args[1:] # type: Sequence(str) else: - _argnames = argnames - - if exclusion_combinations: - for exclusion, combination in exclusion_combinations: - combination_by_kw = { - argname: val - for argname, val in zip(_argnames, combination) - } - exclusion = exclusion.with_combination( - **combination_by_kw - ) - fn = exclusion(fn) - return pytest.mark.parametrize(_argnames, arg_sets)(fn) + _argnames = re.split( + r", *", argnames + ) # type: Sequence(str) + + if has_exclusions: + _argnames += ["_exclusions"] + + @_pytest_fn_decorator + def check_exclusions(fn, *args, **kw): + _exclusions = args[-1] + if _exclusions: + exlu = exclusions.compound().add(*_exclusions) + fn = exlu(fn) + return fn(*args[0:-1], **kw) + + def process_metadata(spec): + spec.args.append("_exclusions") + + fn = check_exclusions( + fn, add_positional_parameters=("_exclusions",) + ) + + return pytest.mark.parametrize(_argnames, pytest_params)(fn) return decorate |