diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r-- | lib/sqlalchemy/testing/asyncio.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 12 |
4 files changed, 27 insertions, 17 deletions
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 52386d33e..bdf730a4c 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -22,12 +22,17 @@ import inspect from . import config from ..util.concurrency import _util_async_run +from ..util.concurrency import _util_async_run_coroutine_function # may be set to False if the # --disable-asyncio flag is passed to the test runner. ENABLE_ASYNCIO = True +def _run_coroutine_function(fn, *args, **kwargs): + return _util_async_run_coroutine_function(fn, *args, **kwargs) + + def _assume_async(fn, *args, **kwargs): """Run a function in an asyncio loop unconditionally. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d0a1bc0d0..4d4563afb 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -97,7 +97,10 @@ class ConnectionKiller(object): self.conns = set() for rec in list(self.testing_engines): - rec.dispose() + if hasattr(rec, "sync_engine"): + rec.sync_engine.dispose() + else: + rec.dispose() def assert_all_closed(self): for rec in self.proxy_refs: @@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False): +def testing_engine(url=None, options=None, future=False, asyncio=False): """Produce an engine configured by --options with optional overrides.""" - if future or config.db and config.db._is_future: + if asyncio: + from sqlalchemy.ext.asyncio import create_async_engine as create_engine + elif future or config.db and config.db._is_future: from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index a52fdd196..0ede25176 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -48,11 +48,6 @@ class TestBase(object): # skipped. __skip_if__ = None - # If this class should be wrapped in asyncio compatibility functions - # when using an async engine. This should be set to False only for tests - # that use the asyncio features of sqlalchemy directly - __asyncio_wrap__ = True - def assert_(self, val, msg=None): assert val, msg @@ -95,12 +90,6 @@ class TestBase(object): # engines.drop_all_tables(metadata, config.db) -class AsyncTestBase(TestBase): - """Mixin marking a test as using its own explicit asyncio patterns.""" - - __asyncio_wrap__ = False - - class FutureEngineMixin(object): @classmethod def setup_class(cls): diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 6be64aa61..46468a07d 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config - if config.any_async and getattr(obj, "__asyncio_wrap__", True): + if config.any_async: obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio @@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True): (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) and (name.startswith("test_") or name in setup_names) + and not _is_wrapped_coroutine_function(value) ): is_classmethod = False if isinstance(value, classmethod): @@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio._assume_async(fn, *args, **kwargs) + asyncio._run_coroutine_function(fn, *args, **kwargs) return decorate(fn) |