summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/plugin/pytestplugin.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py164
1 files changed, 123 insertions, 41 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 3e0630890..c39f9f32e 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -7,6 +7,7 @@ except ImportError:
import argparse
import collections
+from functools import update_wrapper
import inspect
import itertools
import operator
@@ -16,6 +17,13 @@ import sys
import pytest
+try:
+ import typing
+except ImportError:
+ pass
+else:
+ if typing.TYPE_CHECKING:
+ from typing import Sequence
try:
import xdist # noqa
@@ -295,6 +303,49 @@ def getargspec(fn):
return inspect.getargspec(fn)
+def _pytest_fn_decorator(target):
+ """Port of langhelpers.decorator with pytest-specific tricks."""
+
+ from sqlalchemy.util.langhelpers import format_argspec_plus
+ from sqlalchemy.util.compat import inspect_getfullargspec
+
+ def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+ def decorate(fn, add_positional_parameters=()):
+
+ spec = inspect_getfullargspec(fn)
+ if add_positional_parameters:
+ spec.args.extend(add_positional_parameters)
+
+ metadata = dict(target="target", fn="fn", name=fn.__name__)
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(target)s(%(fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ decorated = _exec_code_in_env(
+ code, {"target": target, "fn": fn}, fn.__name__
+ )
+ if not add_positional_parameters:
+ decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+ else:
+ # this is the pytest hacky part. don't do a full update wrapper
+ # because pytest is really being sneaky about finding the args
+ # for the wrapped function
+ decorated.__module__ = fn.__module__
+ decorated.__name__ = fn.__name__
+ return decorated
+
+ return decorate
+
+
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
@@ -326,8 +377,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
argnames = kw.pop("argnames", None)
- exclusion_combinations = []
-
def _filter_exclusions(args):
result = []
gathered_exclusions = []
@@ -337,13 +386,12 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
else:
result.append(a)
- exclusion_combinations.extend(
- [(exclusion, result) for exclusion in gathered_exclusions]
- )
- return result
+ return result, gathered_exclusions
id_ = kw.pop("id_", None)
+ tobuild_pytest_params = []
+ has_exclusions = False
if id_:
_combination_id_fns = self._combination_id_fns
@@ -364,53 +412,87 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
if char in _combination_id_fns
]
- arg_sets = [
- pytest.param(
- *_arg_getter(_filter_exclusions(arg))[1:],
- id="-".join(
- comb_fn(getter(arg)) for getter, comb_fn in fns
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ parameters = _arg_getter(fn_params)[1:]
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (
+ parameters,
+ param_exclusions,
+ "-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ ),
)
)
- for arg in [
- (arg,) if not isinstance(arg, tuple) else arg
- for arg in arg_sets
- ]
- ]
+
else:
- # ensure using pytest.param so that even a 1-arg paramset
- # still needs to be a tuple. otherwise paramtrize tries to
- # interpret a single arg differently than tuple arg
- arg_sets = [
- pytest.param(*_filter_exclusions(arg))
- for arg in [
- (arg,) if not isinstance(arg, tuple) else arg
- for arg in arg_sets
- ]
- ]
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (fn_params, param_exclusions, None)
+ )
+
+ pytest_params = []
+ for parameters, param_exclusions, id_ in tobuild_pytest_params:
+ if has_exclusions:
+ parameters += (param_exclusions,)
+
+ param = pytest.param(*parameters, id=id_)
+ pytest_params.append(param)
def decorate(fn):
if inspect.isclass(fn):
+ if has_exclusions:
+ raise NotImplementedError(
+ "exclusions not supported for class level combinations"
+ )
if "_sa_parametrize" not in fn.__dict__:
fn._sa_parametrize = []
- fn._sa_parametrize.append((argnames, arg_sets))
+ fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
if argnames is None:
- _argnames = getargspec(fn).args[1:]
+ _argnames = getargspec(fn).args[1:] # type: Sequence(str)
else:
- _argnames = argnames
-
- if exclusion_combinations:
- for exclusion, combination in exclusion_combinations:
- combination_by_kw = {
- argname: val
- for argname, val in zip(_argnames, combination)
- }
- exclusion = exclusion.with_combination(
- **combination_by_kw
- )
- fn = exclusion(fn)
- return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+ _argnames = re.split(
+ r", *", argnames
+ ) # type: Sequence(str)
+
+ if has_exclusions:
+ _argnames += ["_exclusions"]
+
+ @_pytest_fn_decorator
+ def check_exclusions(fn, *args, **kw):
+ _exclusions = args[-1]
+ if _exclusions:
+ exlu = exclusions.compound().add(*_exclusions)
+ fn = exlu(fn)
+ return fn(*args[0:-1], **kw)
+
+ def process_metadata(spec):
+ spec.args.append("_exclusions")
+
+ fn = check_exclusions(
+ fn, add_positional_parameters=("_exclusions",)
+ )
+
+ return pytest.mark.parametrize(_argnames, pytest_params)(fn)
return decorate