summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/plugin/pytestplugin.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index de0f8862a..2f7df97fa 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -315,6 +315,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
ids for parameter sets are derived using an optional template.
"""
+ from sqlalchemy.testing import exclusions
if sys.version_info.major == 3:
if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
@@ -325,6 +326,22 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
argnames = kw.pop("argnames", None)
+ exclusion_combinations = []
+
+ def _filter_exclusions(args):
+ result = []
+ gathered_exclusions = []
+ for a in args:
+ if isinstance(a, exclusions.compound):
+ gathered_exclusions.append(a)
+ else:
+ result.append(a)
+
+ exclusion_combinations.extend(
+ [(exclusion, result) for exclusion in gathered_exclusions]
+ )
+ return result
+
id_ = kw.pop("id_", None)
if id_:
@@ -348,7 +365,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
]
arg_sets = [
pytest.param(
- *_arg_getter(arg)[1:],
+ *_arg_getter(_filter_exclusions(arg))[1:],
id="-".join(
comb_fn(getter(arg)) for getter, comb_fn in fns
)
@@ -359,7 +376,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
# ensure using pytest.param so that even a 1-arg paramset
# still needs to be a tuple. otherwise paramtrize tries to
# interpret a single arg differently than tuple arg
- arg_sets = [pytest.param(*arg) for arg in arg_sets]
+ arg_sets = [
+ pytest.param(*_filter_exclusions(arg)) for arg in arg_sets
+ ]
def decorate(fn):
if inspect.isclass(fn):
@@ -372,6 +391,17 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
_argnames = getargspec(fn).args[1:]
else:
_argnames = argnames
+
+ if exclusion_combinations:
+ for exclusion, combination in exclusion_combinations:
+ combination_by_kw = {
+ argname: val
+ for argname, val in zip(_argnames, combination)
+ }
+ exclusion = exclusion.with_combination(
+ **combination_by_kw
+ )
+ fn = exclusion(fn)
return pytest.mark.parametrize(_argnames, arg_sets)(fn)
return decorate