diff options
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 160 |
1 files changed, 127 insertions, 33 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 644ea6dc2..6be64aa61 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -26,11 +26,6 @@ else: from typing import Sequence try: - import asyncio -except ImportError: - pass - -try: import xdist # noqa has_xdist = True @@ -126,11 +121,15 @@ def collect_types_fixture(): def pytest_sessionstart(session): - plugin_base.post_begin() + from sqlalchemy.testing import asyncio + + asyncio._assume_async(plugin_base.post_begin) def pytest_sessionfinish(session): - plugin_base.final_process_cleanup() + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) if session.config.option.dump_pyannotate: from pyannotate_runtime import collect_types @@ -162,23 +161,31 @@ if has_xdist: import uuid def pytest_configure_node(node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + # the master for each node fills workerinput dictionary # which pytest-xdist will transfer to the subprocess plugin_base.memoize_important_follower_config(node.workerinput) node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] - from sqlalchemy.testing import provision - provision.create_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio - provision.drop_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and # expand them out into per-database test cases. @@ -189,6 +196,8 @@ def pytest_collection_modifyitems(session, config, items): # it's to suit the rather odd use case here which is that we are adding # new classes to a module on the fly. + from sqlalchemy.testing import asyncio + rebuilt_items = collections.defaultdict( lambda: collections.defaultdict(list) ) @@ -201,20 +210,26 @@ def pytest_collection_modifyitems(session, config, items): ] test_classes = set(item.parent for item in items) - for test_class in test_classes: - for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module - ): - if sub_cls is not test_class.cls: - per_cls_dict = rebuilt_items[test_class.cls] - # support pytest 5.4.0 and above pytest.Class.from_parent - ctor = getattr(pytest.Class, "from_parent", pytest.Class) - for inst in ctor( - name=sub_cls.__name__, parent=test_class.parent.parent - ).collect(): - for t in inst.collect(): - per_cls_dict[t.name].append(t) + def setup_test_classes(): + for test_class in test_classes: + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.parent.module + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + # support pytest 5.4.0 and above pytest.Class.from_parent + ctor = getattr(pytest.Class, "from_parent", pytest.Class) + for inst in ctor( + name=sub_cls.__name__, parent=test_class.parent.parent + ).collect(): + for t in inst.collect(): + per_cls_dict[t.name].append(t) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) newitems = [] for item in items: @@ -238,6 +253,10 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async and getattr(obj, "__asyncio_wrap__", True): + obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -258,6 +277,38 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + setup_names = {"setup", "setup_class", "teardown", "teardown_class"} + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_") or name in setup_names) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + _current_class = None @@ -297,6 +348,8 @@ def _parametrize_cls(module, cls): def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + # here we seem to get called only based on what we collected # in pytest_collection_modifyitems. So to do class-based stuff # we have to tear that out. @@ -307,7 +360,7 @@ def pytest_runtest_setup(item): # ... so we're doing a little dance here to figure it out... if _current_class is None: - class_setup(item.parent.parent) + asyncio._maybe_async(class_setup, item.parent.parent) _current_class = item.parent.parent # this is needed for the class-level, to ensure that the @@ -315,20 +368,22 @@ def pytest_runtest_setup(item): # class-level teardown... def finalize(): global _current_class - class_teardown(item.parent.parent) + asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None item.parent.parent.addfinalizer(finalize) - test_setup(item) + asyncio._maybe_async(test_setup, item) def pytest_runtest_teardown(item): + from sqlalchemy.testing import asyncio + # ...but this works better as the hook here rather than # using a finalizer, as the finalizer seems to get in the way # of the test reporting failures correctly (you get a bunch of # pytest assertion stuff instead) - test_teardown(item) + asyncio._maybe_async(test_teardown, item) def test_setup(item): @@ -342,7 +397,9 @@ def test_teardown(item): def class_setup(item): - plugin_base.start_test_class(item.cls) + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) def class_teardown(item): @@ -372,17 +429,19 @@ def _pytest_fn_decorator(target): if add_positional_parameters: spec.args.extend(add_positional_parameters) - metadata = dict(target="target", fn="__fn", name=fn.__name__) + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) metadata.update(format_argspec_plus(spec, grouped=False)) code = ( """\ def %(name)s(%(args)s): - return %(target)s(%(fn)s, %(apply_kw)s) + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) """ % metadata ) decorated = _exec_code_in_env( - code, {"target": target, "__fn": fn}, fn.__name__ + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ ) if not add_positional_parameters: decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ @@ -554,14 +613,49 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): return pytest.param(*parameters[1:], id=ident) def fixture(self, *arg, **kw): - return pytest.fixture(*arg, **kw) + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn + + if fn: + return wrap(fn) + else: + return wrap def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") def async_test(self, fn): + from sqlalchemy.testing import asyncio + @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + asyncio._assume_async(fn, *args, **kwargs) return decorate(fn) |