diff options
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/asyncio.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/util/concurrency.py | 6 | ||||
-rw-r--r-- | test/base/test_concurrency_py3k.py | 7 | ||||
-rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 9 | ||||
-rw-r--r-- | test/ext/asyncio/test_session_py3k.py | 4 |
10 files changed, 60 insertions, 24 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 16edcc2b2..93adaf78a 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -41,7 +41,7 @@ def create_async_engine(*arg, **kw): class AsyncConnectable: - __slots__ = "_slots_dispatch" + __slots__ = "_slots_dispatch", "__weakref__" @util.create_proxy_methods( diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 52386d33e..bdf730a4c 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -22,12 +22,17 @@ import inspect from . import config from ..util.concurrency import _util_async_run +from ..util.concurrency import _util_async_run_coroutine_function # may be set to False if the # --disable-asyncio flag is passed to the test runner. ENABLE_ASYNCIO = True +def _run_coroutine_function(fn, *args, **kwargs): + return _util_async_run_coroutine_function(fn, *args, **kwargs) + + def _assume_async(fn, *args, **kwargs): """Run a function in an asyncio loop unconditionally. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d0a1bc0d0..4d4563afb 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -97,7 +97,10 @@ class ConnectionKiller(object): self.conns = set() for rec in list(self.testing_engines): - rec.dispose() + if hasattr(rec, "sync_engine"): + rec.sync_engine.dispose() + else: + rec.dispose() def assert_all_closed(self): for rec in self.proxy_refs: @@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False): +def testing_engine(url=None, options=None, future=False, asyncio=False): """Produce an engine configured by --options with optional overrides.""" - if future or config.db and config.db._is_future: + if asyncio: + from sqlalchemy.ext.asyncio import create_async_engine as create_engine + elif future or config.db and config.db._is_future: from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index a52fdd196..0ede25176 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -48,11 +48,6 @@ class TestBase(object): # skipped. __skip_if__ = None - # If this class should be wrapped in asyncio compatibility functions - # when using an async engine. This should be set to False only for tests - # that use the asyncio features of sqlalchemy directly - __asyncio_wrap__ = True - def assert_(self, val, msg=None): assert val, msg @@ -95,12 +90,6 @@ class TestBase(object): # engines.drop_all_tables(metadata, config.db) -class AsyncTestBase(TestBase): - """Mixin marking a test as using its own explicit asyncio patterns.""" - - __asyncio_wrap__ = False - - class FutureEngineMixin(object): @classmethod def setup_class(cls): diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 6be64aa61..46468a07d 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config - if config.any_async and getattr(obj, "__asyncio_wrap__", True): + if config.any_async: obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio @@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True): (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) and (name.startswith("test_") or name in setup_names) + and not _is_wrapped_coroutine_function(value) ): is_classmethod = False if isinstance(value, classmethod): @@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio._assume_async(fn, *args, **kwargs) + asyncio._run_coroutine_function(fn, *args, **kwargs) return decorate(fn) diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 6042e4395..663d3e0f4 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -136,6 +136,18 @@ class AsyncAdaptedLock: self.mutex.release() +def _util_async_run_coroutine_function(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = asyncio.get_event_loop() + if loop.is_running(): + raise Exception( + "for async run coroutine we expect that no greenlet or event " + "loop is running when we start out" + ) + return loop.run_until_complete(fn(*args, **kwargs)) + + def _util_async_run(fn, *args, **kwargs): """for test suite/ util only""" diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 7b4ff6ba4..c44efba62 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -14,6 +14,9 @@ if compat.py3k: from ._concurrency_py3k import greenlet_spawn from ._concurrency_py3k import AsyncAdaptedLock from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function, + ) # noqa F401, E501 from ._concurrency_py3k import asyncio # noqa F401 if not have_greenlet: @@ -42,3 +45,6 @@ if not have_greenlet: def _util_async_run(fn, *arg, **kw): # noqa F81 return fn(*arg, **kw) + + def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81 + _not_implemented() diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index 2cc2075bc..e7ae8c9ad 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -26,7 +26,7 @@ def go(*fns): return sum(await_only(fn()) for fn in fns) -class TestAsyncioCompat(fixtures.AsyncTestBase): +class TestAsyncioCompat(fixtures.TestBase): @async_test async def test_ok(self): @@ -53,7 +53,8 @@ class TestAsyncioCompat(fixtures.AsyncTestBase): to_await = run1() await_fallback(to_await) - def test_await_only_no_greenlet(self): + @async_test + async def test_await_only_no_greenlet(self): to_await = run1() with expect_raises_message( exc.InvalidRequestError, @@ -62,7 +63,7 @@ class TestAsyncioCompat(fixtures.AsyncTestBase): await_only(to_await) # ensure no warning - await_fallback(to_await) + await greenlet_spawn(await_fallback, to_await) @async_test async def test_await_fallback_error(self): diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cd1e16ed9..7dae1411e 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.testing import async_test from sqlalchemy.testing import combinations +from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message @@ -32,7 +33,7 @@ class EngineFixture(fixtures.TablesTest): @testing.fixture def async_engine(self): - return create_async_engine(testing.db.url) + return engines.testing_engine(asyncio=True) @classmethod def define_tables(cls, metadata): @@ -55,6 +56,12 @@ class EngineFixture(fixtures.TablesTest): class AsyncEngineTest(EngineFixture): __backend__ = True + @testing.fails("the failure is the test") + @async_test + async def test_we_are_definitely_running_async_tests(self, async_engine): + async with async_engine.connect() as conn: + eq_(await conn.scalar(text("select 1")), 2) + def test_proxied_attrs_engine(self, async_engine): sync_engine = async_engine.sync_engine diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 37e1b807b..dbe84e82c 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -5,10 +5,10 @@ from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import selectinload from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import async_test +from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ from sqlalchemy.testing import mock @@ -24,7 +24,7 @@ class AsyncFixture(_fixtures.FixtureTest): @testing.fixture def async_engine(self): - return create_async_engine(testing.db.url) + return engines.testing_engine(asyncio=True) @testing.fixture def async_session(self, async_engine): |