diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-25 09:10:09 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-03 18:42:52 -0400 |
commit | b96321ae79a0366c33ca739e6e67aaf5f4420db4 (patch) | |
tree | d56cb4cdf58e0b060f1ceb14f468eef21de0688b /test/ext/asyncio/test_engine_py3k.py | |
parent | 9bae9a931a460ff70172858ff90bcc1defae8e20 (diff) | |
download | sqlalchemy-b96321ae79a0366c33ca739e6e67aaf5f4420db4.tar.gz |
Support result.close() for all iterator patterns
This change contains new features for 2.0 only as well as some
behaviors that will be backported to 1.4.
For 1.4 and 2.0:
Fixed issue where the underlying DBAPI cursor would not be closed when
using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct
iteration, if a user-defined exception case were raised within the
iteration process, interrupting the iterator. This would lead to the usual
MySQL-related issues with server side cursors out of sync.
For 1.4 only:
A similar scenario can occur when using :term:`2.x` executions with direct
use of :class:`.Result`, in that case the end-user code has access to the
:class:`.Result` itself and should call :meth:`.Result.close` directly.
Version 2.0 will feature context-manager calling patterns to address this
use case. However within the 1.4 scope, ensured that ``.close()`` methods
are available on all :class:`.Result` implementations including
:class:`.ScalarResult`, :class:`.MappingResult`.
For 2.0 only:
To better support the use case of iterating :class:`.Result` and
:class:`.AsyncResult` objects where user-defined exceptions may interrupt
the iteration, both objects as well as variants such as
:class:`.ScalarResult`, :class:`.MappingResult`,
:class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support
context manager usage, where the result will be closed at the end of
iteration.
Corrected various typing issues within the engine and async engine
packages.
Fixes: #8710
Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 49 |
1 files changed, 46 insertions, 3 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cdf70ca67..2eebb433d 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture): ): await conn.exec_driver_sql("SELECT * FROM users") + @async_test + async def test_stream_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream(select(self.tables.users)) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for row in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + + @async_test + async def test_stream_scalars_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream_scalars( + select(self.tables.users) + ) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for scalar in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" ) @@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture): eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) @testing.combinations( - (None,), ("scalars",), ("mappings",), argnames="filter_" + (None,), + ("scalars",), + ("stream_scalars",), + ("mappings",), + argnames="filter_", ) @async_test async def test_aiter(self, async_engine, filter_): users = self.tables.users async with async_engine.connect() as conn: - result = await conn.stream(select(users)) + if filter_ == "stream_scalars": + result = await conn.stream_scalars(select(users.c.user_name)) + else: + result = await conn.stream(select(users)) if filter_ == "mappings": result = result.mappings() @@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture): for i in range(1, 20) ], ) - elif filter_ == "scalars": + elif filter_ in ("scalars", "stream_scalars"): eq_( rows, ["name%d" % i for i in range(1, 20)], |