summaryrefslogtreecommitdiff
path: root/test/ext/asyncio/test_engine_py3k.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r--test/ext/asyncio/test_engine_py3k.py49
1 files changed, 46 insertions, 3 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index cdf70ca67..2eebb433d 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture):
):
await conn.exec_driver_sql("SELECT * FROM users")
+ @async_test
+ async def test_stream_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream(select(self.tables.users)) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for row in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
+ @async_test
+ async def test_stream_scalars_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream_scalars(
+ select(self.tables.users)
+ ) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for scalar in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
@testing.combinations(
(None,), ("scalars",), ("mappings",), argnames="filter_"
)
@@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture):
eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
@testing.combinations(
- (None,), ("scalars",), ("mappings",), argnames="filter_"
+ (None,),
+ ("scalars",),
+ ("stream_scalars",),
+ ("mappings",),
+ argnames="filter_",
)
@async_test
async def test_aiter(self, async_engine, filter_):
users = self.tables.users
async with async_engine.connect() as conn:
- result = await conn.stream(select(users))
+ if filter_ == "stream_scalars":
+ result = await conn.stream_scalars(select(users.c.user_name))
+ else:
+ result = await conn.stream(select(users))
if filter_ == "mappings":
result = result.mappings()
@@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture):
for i in range(1, 20)
],
)
- elif filter_ == "scalars":
+ elif filter_ in ("scalars", "stream_scalars"):
eq_(
rows,
["name%d" % i for i in range(1, 20)],