diff options
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 139 |
1 files changed, 138 insertions, 1 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 820c82bca..18e55ff92 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -17,8 +17,10 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.pool import AsyncAdaptedQueuePool +from sqlalchemy.testing import assertions from sqlalchemy.testing import async_test from sqlalchemy.testing import combinations +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -34,7 +36,133 @@ from sqlalchemy.testing import ne_ from sqlalchemy.util.concurrency import greenlet_spawn -class EngineFixture(fixtures.TablesTest): +class AsyncFixture: + @config.fixture( + params=[ + (rollback, run_second_execute, begin_nested) + for rollback in (True, False) + for run_second_execute in (True, False) + for begin_nested in (True, False) + ] + ) + def async_trans_ctx_manager_fixture(self, request, metadata): + rollback, run_second_execute, begin_nested = request.param + + from sqlalchemy import Table, Column, Integer, func, select + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + async def run_test(subject, trans_on_subject, execute_on_subject): + async with subject.begin() as trans: + + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + async with nested_trans: + if execute_on_subject: + await subject.execute(t.insert(), {"data": 10}) + else: + await trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + await nested_trans.rollback() + else: + await nested_trans.commit() + + if run_second_execute: + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context manager. Please complete the " + "context manager " + "before emitting further commands.", + ): + if execute_on_subject: + await subject.execute( + t.insert(), {"data": 12} + ) + else: + await trans.execute( + t.insert(), {"data": 12} + ) + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + await subject.execute(t.insert(), {"data": 14}) + else: + await trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + await subject.execute(t.insert(), {"data": 10}) + else: + await trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + await subject.rollback() + else: + await subject.commit() + else: + if rollback: + await trans.rollback() + else: + await trans.commit() + + if run_second_execute: + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if execute_on_subject: + await subject.execute(t.insert(), {"data": 12}) + else: + await trans.execute(t.insert(), {"data": 12}) + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + await subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + await conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +class EngineFixture(AsyncFixture, fixtures.TablesTest): __requires__ = ("async_dialect",) @testing.fixture @@ -68,6 +196,15 @@ class AsyncEngineTest(EngineFixture): async with async_engine.connect() as conn: eq_(await conn.scalar(text("select 1")), 2) + @async_test + async def test_interrupt_ctxmanager_connection( + self, async_engine, async_trans_ctx_manager_fixture + ): + fn = async_trans_ctx_manager_fixture + + async with async_engine.connect() as conn: + await fn(conn, trans_on_subject=False, execute_on_subject=True) + def test_proxied_attrs_engine(self, async_engine): sync_engine = async_engine.sync_engine |