diff options
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/engine.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 162 |
1 files changed, 105 insertions, 57 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 389088875..854039c98 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,7 +7,9 @@ from __future__ import annotations import asyncio +import contextlib from typing import Any +from typing import AsyncIterator from typing import Callable from typing import Dict from typing import Generator @@ -21,6 +23,8 @@ from typing import TypeVar from typing import Union from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext from .base import ProxyComparable from .base import StartableContext from .result import _ensure_sync_result @@ -133,7 +137,9 @@ class AsyncConnectable: ], ) class AsyncConnection( - ProxyComparable[Connection], StartableContext, AsyncConnectable + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, ): """An asyncio proxy for a :class:`_engine.Connection`. @@ -446,34 +452,67 @@ class AsyncConnection( return await _ensure_sync_result(result, self.exec_driver_sql) @overload - async def stream( + def stream( self, statement: TypedReturnsRows[_T], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[_T]: + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... @overload - async def stream( + def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + @asyncstartablecontext async def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: - """Execute a statement and return a streaming - :class:`_asyncio.AsyncResult` object.""" + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt): + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ result = await greenlet_spawn( self._proxied.execute, @@ -484,10 +523,15 @@ class AsyncConnection( ), _require_await=True, ) - if not result.context._is_server_side: - # TODO: real exception here - assert False, "server side result expected" - return AsyncResult(result) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) @overload async def execute( @@ -642,48 +686,77 @@ class AsyncConnection( return result.scalars() @overload - async def stream_scalars( + def stream_scalars( self, statement: TypedReturnsRows[Tuple[_T]], parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[_T]: + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... @overload - async def stream_scalars( + def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + @asyncstartablecontext async def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: - r"""Executes a SQL statement and returns a streaming scalar result - object. + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt): + async for scalar in result: + print(f"{scalar}") This method is shorthand for invoking the :meth:`_engine.AsyncResult.scalars` method after invoking the :meth:`_engine.Connection.stream` method. Parameters are equivalent. - :return: an :class:`_asyncio.AsyncScalarResult` object. + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. .. versionadded:: 1.4.24 + .. seealso:: + + :meth:`.AsyncConnection.stream` + """ - result = await self.stream( + + async with self.stream( statement, parameters, execution_options=execution_options - ) - return result.scalars() + ) as result: + yield result.scalars() async def run_sync( self, fn: Callable[..., Any], *arg: Any, **kw: Any @@ -856,37 +929,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): :ref:`asyncio_events` """ - class _trans_ctx(StartableContext): - __slots__ = ("conn", "transaction") - - conn: AsyncConnection - transaction: AsyncTransaction - - def __init__(self, conn: AsyncConnection): - self.conn = conn - - if TYPE_CHECKING: - - async def __aenter__(self) -> AsyncConnection: - ... - - async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: - await self.conn.start(is_ctxmanager=is_ctxmanager) - self.transaction = self.conn.begin() - await self.transaction.__aenter__() - - return self.conn - - async def __aexit__( - self, type_: Any, value: Any, traceback: Any - ) -> None: - async def go() -> None: - await self.transaction.__aexit__(type_, value, traceback) - await self.conn.close() - - task = asyncio.create_task(go()) - await asyncio.shield(task) - def __init__(self, sync_engine: Engine): if not sync_engine.dialect.is_async: raise exc.InvalidRequestError( @@ -903,7 +945,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) - def begin(self) -> AsyncEngine._trans_ctx: + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: """Return a context manager which when entered will deliver an :class:`_asyncio.AsyncConnection` with an :class:`_asyncio.AsyncTransaction` established. @@ -919,7 +962,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): """ conn = self.connect() - return self._trans_ctx(conn) + + async with conn: + async with conn.begin(): + yield conn def connect(self) -> AsyncConnection: """Return an :class:`_asyncio.AsyncConnection` object. @@ -1185,7 +1231,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): # END PROXY METHODS AsyncEngine -class AsyncTransaction(ProxyComparable[Transaction], StartableContext): +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") |