summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/plugin/pytestplugin.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 015598952..3df239afa 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -26,6 +26,11 @@ else:
from typing import Sequence
try:
+ import asyncio
+except ImportError:
+ pass
+
+try:
import xdist # noqa
has_xdist = True
@@ -101,6 +106,24 @@ def pytest_configure(config):
plugin_base.set_fixture_functions(PytestFixtureFunctions)
+ if config.option.dump_pyannotate:
+ global DUMP_PYANNOTATE
+ DUMP_PYANNOTATE = True
+
+
+DUMP_PYANNOTATE = False
+
+
+@pytest.fixture(autouse=True)
+def collect_types_fixture():
+ if DUMP_PYANNOTATE:
+ from pyannotate_runtime import collect_types
+
+ collect_types.start()
+ yield
+ if DUMP_PYANNOTATE:
+ collect_types.stop()
+
def pytest_sessionstart(session):
plugin_base.post_begin()
@@ -109,6 +132,31 @@ def pytest_sessionstart(session):
def pytest_sessionfinish(session):
plugin_base.final_process_cleanup()
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
+def pytest_collection_finish(session):
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
+
+ def _filter(filename):
+ filename = os.path.normpath(os.path.abspath(filename))
+ if "lib/sqlalchemy" not in os.path.commonpath(
+ [filename, lib_sqlalchemy]
+ ):
+ return None
+ if "testing" in filename:
+ return None
+
+ return filename
+
+ collect_types.init_types_collection(filter_filename=_filter)
+
if has_xdist:
import uuid
@@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def get_current_test_name(self):
return os.environ.get("PYTEST_CURRENT_TEST")
+
+ def async_test(self, fn):
+ @_pytest_fn_decorator
+ def decorate(fn, *args, **kwargs):
+ asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs))
+
+ return decorate(fn)