summaryrefslogtreecommitdiff
path: root/test/ext/asyncio/test_engine_py3k.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r--test/ext/asyncio/test_engine_py3k.py139
1 files changed, 138 insertions, 1 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 820c82bca..18e55ff92 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -17,8 +17,10 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as asyncio_exc
from sqlalchemy.pool import AsyncAdaptedQueuePool
+from sqlalchemy.testing import assertions
from sqlalchemy.testing import async_test
from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
@@ -34,7 +36,133 @@ from sqlalchemy.testing import ne_
from sqlalchemy.util.concurrency import greenlet_spawn
-class EngineFixture(fixtures.TablesTest):
+class AsyncFixture:
+ @config.fixture(
+ params=[
+ (rollback, run_second_execute, begin_nested)
+ for rollback in (True, False)
+ for run_second_execute in (True, False)
+ for begin_nested in (True, False)
+ ]
+ )
+ def async_trans_ctx_manager_fixture(self, request, metadata):
+ rollback, run_second_execute, begin_nested = request.param
+
+ from sqlalchemy import Table, Column, Integer, func, select
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ async def run_test(subject, trans_on_subject, execute_on_subject):
+ async with subject.begin() as trans:
+
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ async with nested_trans:
+ if execute_on_subject:
+ await subject.execute(t.insert(), {"data": 10})
+ else:
+ await trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ await nested_trans.rollback()
+ else:
+ await nested_trans.commit()
+
+ if run_second_execute:
+ with assertions.expect_raises_message(
+ exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context manager. Please complete the "
+ "context manager "
+ "before emitting further commands.",
+ ):
+ if execute_on_subject:
+ await subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ await trans.execute(
+ t.insert(), {"data": 12}
+ )
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ await subject.execute(t.insert(), {"data": 14})
+ else:
+ await trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ await subject.execute(t.insert(), {"data": 10})
+ else:
+ await trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ await subject.rollback()
+ else:
+ await subject.commit()
+ else:
+ if rollback:
+ await trans.rollback()
+ else:
+ await trans.commit()
+
+ if run_second_execute:
+ with assertions.expect_raises_message(
+ exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if execute_on_subject:
+ await subject.execute(t.insert(), {"data": 12})
+ else:
+ await trans.execute(t.insert(), {"data": 12})
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ await subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ await conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+class EngineFixture(AsyncFixture, fixtures.TablesTest):
__requires__ = ("async_dialect",)
@testing.fixture
@@ -68,6 +196,15 @@ class AsyncEngineTest(EngineFixture):
async with async_engine.connect() as conn:
eq_(await conn.scalar(text("select 1")), 2)
+ @async_test
+ async def test_interrupt_ctxmanager_connection(
+ self, async_engine, async_trans_ctx_manager_fixture
+ ):
+ fn = async_trans_ctx_manager_fixture
+
+ async with async_engine.connect() as conn:
+ await fn(conn, trans_on_subject=False, execute_on_subject=True)
+
def test_proxied_attrs_engine(self, async_engine):
sync_engine = async_engine.sync_engine