summaryrefslogtreecommitdiff
path: root/test/ext/asyncio/test_engine_py3k.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-25 09:10:09 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-11-03 18:42:52 -0400
commitb96321ae79a0366c33ca739e6e67aaf5f4420db4 (patch)
treed56cb4cdf58e0b060f1ceb14f468eef21de0688b /test/ext/asyncio/test_engine_py3k.py
parent9bae9a931a460ff70172858ff90bcc1defae8e20 (diff)
downloadsqlalchemy-b96321ae79a0366c33ca739e6e67aaf5f4420db4.tar.gz
Support result.close() for all iterator patterns
This change contains new features for 2.0 only as well as some behaviors that will be backported to 1.4. For 1.4 and 2.0: Fixed issue where the underlying DBAPI cursor would not be closed when using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct iteration, if a user-defined exception case were raised within the iteration process, interrupting the iterator. This would lead to the usual MySQL-related issues with server side cursors out of sync. For 1.4 only: A similar scenario can occur when using :term:`2.x` executions with direct use of :class:`.Result`, in that case the end-user code has access to the :class:`.Result` itself and should call :meth:`.Result.close` directly. Version 2.0 will feature context-manager calling patterns to address this use case. However within the 1.4 scope, ensured that ``.close()`` methods are available on all :class:`.Result` implementations including :class:`.ScalarResult`, :class:`.MappingResult`. For 2.0 only: To better support the use case of iterating :class:`.Result` and :class:`.AsyncResult` objects where user-defined exceptions may interrupt the iteration, both objects as well as variants such as :class:`.ScalarResult`, :class:`.MappingResult`, :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support context manager usage, where the result will be closed at the end of iteration. Corrected various typing issues within the engine and async engine packages. Fixes: #8710 Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r--test/ext/asyncio/test_engine_py3k.py49
1 files changed, 46 insertions, 3 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index cdf70ca67..2eebb433d 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture):
):
await conn.exec_driver_sql("SELECT * FROM users")
+ @async_test
+ async def test_stream_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream(select(self.tables.users)) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for row in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
+ @async_test
+ async def test_stream_scalars_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream_scalars(
+ select(self.tables.users)
+ ) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for scalar in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
@testing.combinations(
(None,), ("scalars",), ("mappings",), argnames="filter_"
)
@@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture):
eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
@testing.combinations(
- (None,), ("scalars",), ("mappings",), argnames="filter_"
+ (None,),
+ ("scalars",),
+ ("stream_scalars",),
+ ("mappings",),
+ argnames="filter_",
)
@async_test
async def test_aiter(self, async_engine, filter_):
users = self.tables.users
async with async_engine.connect() as conn:
- result = await conn.stream(select(users))
+ if filter_ == "stream_scalars":
+ result = await conn.stream_scalars(select(users.c.user_name))
+ else:
+ result = await conn.stream(select(users))
if filter_ == "mappings":
result = result.mappings()
@@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture):
for i in range(1, 20)
],
)
- elif filter_ == "scalars":
+ elif filter_ in ("scalars", "stream_scalars"):
eq_(
rows,
["name%d" % i for i in range(1, 20)],