summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/asyncio.py5
-rw-r--r--lib/sqlalchemy/testing/engines.py16
-rw-r--r--lib/sqlalchemy/testing/fixtures.py11
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py12
4 files changed, 27 insertions, 17 deletions
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
index 52386d33e..bdf730a4c 100644
--- a/lib/sqlalchemy/testing/asyncio.py
+++ b/lib/sqlalchemy/testing/asyncio.py
@@ -22,12 +22,17 @@ import inspect
from . import config
from ..util.concurrency import _util_async_run
+from ..util.concurrency import _util_async_run_coroutine_function
# may be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True
+def _run_coroutine_function(fn, *args, **kwargs):
+ return _util_async_run_coroutine_function(fn, *args, **kwargs)
+
+
def _assume_async(fn, *args, **kwargs):
"""Run a function in an asyncio loop unconditionally.
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index d0a1bc0d0..4d4563afb 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -97,7 +97,10 @@ class ConnectionKiller(object):
self.conns = set()
for rec in list(self.testing_engines):
- rec.dispose()
+ if hasattr(rec, "sync_engine"):
+ rec.sync_engine.dispose()
+ else:
+ rec.dispose()
def assert_all_closed(self):
for rec in self.proxy_refs:
@@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None):
return engine
-def testing_engine(url=None, options=None, future=False):
+def testing_engine(url=None, options=None, future=False, asyncio=False):
"""Produce an engine configured by --options with optional overrides."""
- if future or config.db and config.db._is_future:
+ if asyncio:
+ from sqlalchemy.ext.asyncio import create_async_engine as create_engine
+ elif future or config.db and config.db._is_future:
from sqlalchemy.future import create_engine
else:
from sqlalchemy import create_engine
@@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False):
default_opt.update(options)
engine = create_engine(url, **options)
- engine._has_events = True # enable event blocks, helps with profiling
+ if asyncio:
+ engine.sync_engine._has_events = True
+ else:
+ engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index a52fdd196..0ede25176 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -48,11 +48,6 @@ class TestBase(object):
# skipped.
__skip_if__ = None
- # If this class should be wrapped in asyncio compatibility functions
- # when using an async engine. This should be set to False only for tests
- # that use the asyncio features of sqlalchemy directly
- __asyncio_wrap__ = True
-
def assert_(self, val, msg=None):
assert val, msg
@@ -95,12 +90,6 @@ class TestBase(object):
# engines.drop_all_tables(metadata, config.db)
-class AsyncTestBase(TestBase):
- """Mixin marking a test as using its own explicit asyncio patterns."""
-
- __asyncio_wrap__ = False
-
-
class FutureEngineMixin(object):
@classmethod
def setup_class(cls):
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 6be64aa61..46468a07d 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
- if config.any_async and getattr(obj, "__asyncio_wrap__", True):
+ if config.any_async:
obj = _apply_maybe_async(obj)
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
@@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj):
return []
+def _is_wrapped_coroutine_function(fn):
+ while hasattr(fn, "__wrapped__"):
+ fn = fn.__wrapped__
+
+ return inspect.iscoroutinefunction(fn)
+
+
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
@@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True):
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
and (name.startswith("test_") or name in setup_names)
+ and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
if isinstance(value, classmethod):
@@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
@_pytest_fn_decorator
def decorate(fn, *args, **kwargs):
- asyncio._assume_async(fn, *args, **kwargs)
+ asyncio._run_coroutine_function(fn, *args, **kwargs)
return decorate(fn)