diff options
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 015598952..3df239afa 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -26,6 +26,11 @@ else: from typing import Sequence try: + import asyncio +except ImportError: + pass + +try: import xdist # noqa has_xdist = True @@ -101,6 +106,24 @@ def pytest_configure(config): plugin_base.set_fixture_functions(PytestFixtureFunctions) + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + def pytest_sessionstart(session): plugin_base.post_begin() @@ -109,6 +132,31 @@ def pytest_sessionstart(session): def pytest_sessionfinish(session): plugin_base.final_process_cleanup() + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + if has_xdist: import uuid @@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + + return decorate(fn) |