diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-05-02 18:31:03 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-05-05 22:21:07 -0400 |
commit | c5587fda7986df5851491a069830ddd4a63e01ba (patch) | |
tree | 3cbc7f7eb5a1c477032e20ace015a7502d5c2823 /test/ext/asyncio/test_engine_py3k.py | |
parent | ee7a82d71783bf71f3a95550624740e908d178a0 (diff) | |
download | sqlalchemy-c5587fda7986df5851491a069830ddd4a63e01ba.tar.gz |
unify transactional context managers
Applied consistent behavior to the use case of
calling ``.commit()`` or ``.rollback()`` inside of an existing
``.begin()`` context manager, with the addition of potentially
emitting SQL within the block subsequent to the commit or rollback.
This change continues upon the change first added in
:ticket:`6155` where the use case of calling "rollback" inside of
a ``.begin()`` contextmanager block was proposed:
* calling ``.commit()`` or ``.rollback()`` will now be allowed
without error or warning within all scopes, including
that of legacy and future :class:`_engine.Engine`, ORM
:class:`_orm.Session`, asyncio :class:`.AsyncEngine`. Previously,
the :class:`_orm.Session` disallowed this.
* The remaining scope of the context manager is then closed;
when the block ends, a check is emitted to see if the transaction
was already ended, and if so the block returns without action.
* It will now raise **an error** if subsequent SQL of any kind
is emitted within the block, **after** ``.commit()`` or
``.rollback()`` is called. The block should be closed as
the state of the executable object would otherwise be undefined
in this state.
Fixes: #6288
Change-Id: I8b21766ae430f0fa1ac5ef689f4c0fb19fc84336
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 139 |
1 files changed, 138 insertions, 1 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 820c82bca..18e55ff92 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -17,8 +17,10 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.pool import AsyncAdaptedQueuePool +from sqlalchemy.testing import assertions from sqlalchemy.testing import async_test from sqlalchemy.testing import combinations +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -34,7 +36,133 @@ from sqlalchemy.testing import ne_ from sqlalchemy.util.concurrency import greenlet_spawn -class EngineFixture(fixtures.TablesTest): +class AsyncFixture: + @config.fixture( + params=[ + (rollback, run_second_execute, begin_nested) + for rollback in (True, False) + for run_second_execute in (True, False) + for begin_nested in (True, False) + ] + ) + def async_trans_ctx_manager_fixture(self, request, metadata): + rollback, run_second_execute, begin_nested = request.param + + from sqlalchemy import Table, Column, Integer, func, select + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + async def run_test(subject, trans_on_subject, execute_on_subject): + async with subject.begin() as trans: + + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + async with nested_trans: + if execute_on_subject: + await subject.execute(t.insert(), {"data": 10}) + else: + await trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + await nested_trans.rollback() + else: + await nested_trans.commit() + + if run_second_execute: + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context manager. Please complete the " + "context manager " + "before emitting further commands.", + ): + if execute_on_subject: + await subject.execute( + t.insert(), {"data": 12} + ) + else: + await trans.execute( + t.insert(), {"data": 12} + ) + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + await subject.execute(t.insert(), {"data": 14}) + else: + await trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + await subject.execute(t.insert(), {"data": 10}) + else: + await trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + await subject.rollback() + else: + await subject.commit() + else: + if rollback: + await trans.rollback() + else: + await trans.commit() + + if run_second_execute: + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if execute_on_subject: + await subject.execute(t.insert(), {"data": 12}) + else: + await trans.execute(t.insert(), {"data": 12}) + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + await subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + await conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +class EngineFixture(AsyncFixture, fixtures.TablesTest): __requires__ = ("async_dialect",) @testing.fixture @@ -68,6 +196,15 @@ class AsyncEngineTest(EngineFixture): async with async_engine.connect() as conn: eq_(await conn.scalar(text("select 1")), 2) + @async_test + async def test_interrupt_ctxmanager_connection( + self, async_engine, async_trans_ctx_manager_fixture + ): + fn = async_trans_ctx_manager_fixture + + async with async_engine.connect() as conn: + await fn(conn, trans_on_subject=False, execute_on_subject=True) + def test_proxied_attrs_engine(self, async_engine): sync_engine = async_engine.sync_engine |