diff options
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
-rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 49 |
1 files changed, 46 insertions, 3 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cdf70ca67..2eebb433d 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture): ): await conn.exec_driver_sql("SELECT * FROM users") + @async_test + async def test_stream_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream(select(self.tables.users)) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for row in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + + @async_test + async def test_stream_scalars_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream_scalars( + select(self.tables.users) + ) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for scalar in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" ) @@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture): eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) @testing.combinations( - (None,), ("scalars",), ("mappings",), argnames="filter_" + (None,), + ("scalars",), + ("stream_scalars",), + ("mappings",), + argnames="filter_", ) @async_test async def test_aiter(self, async_engine, filter_): users = self.tables.users async with async_engine.connect() as conn: - result = await conn.stream(select(users)) + if filter_ == "stream_scalars": + result = await conn.stream_scalars(select(users.c.user_name)) + else: + result = await conn.stream(select(users)) if filter_ == "mappings": result = result.mappings() @@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture): for i in range(1, 20) ], ) - elif filter_ == "scalars": + elif filter_ in ("scalars", "stream_scalars"): eq_( rows, ["name%d" % i for i in range(1, 20)], |