summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-25 09:10:09 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-11-03 18:42:52 -0400
commitb96321ae79a0366c33ca739e6e67aaf5f4420db4 (patch)
treed56cb4cdf58e0b060f1ceb14f468eef21de0688b /lib/sqlalchemy/ext/asyncio
parent9bae9a931a460ff70172858ff90bcc1defae8e20 (diff)
downloadsqlalchemy-b96321ae79a0366c33ca739e6e67aaf5f4420db4.tar.gz
Support result.close() for all iterator patterns
This change contains new features for 2.0 only as well as some behaviors that will be backported to 1.4. For 1.4 and 2.0: Fixed issue where the underlying DBAPI cursor would not be closed when using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct iteration, if a user-defined exception case were raised within the iteration process, interrupting the iterator. This would lead to the usual MySQL-related issues with server side cursors out of sync. For 1.4 only: A similar scenario can occur when using :term:`2.x` executions with direct use of :class:`.Result`, in that case the end-user code has access to the :class:`.Result` itself and should call :meth:`.Result.close` directly. Version 2.0 will feature context-manager calling patterns to address this use case. However within the 1.4 scope, ensured that ``.close()`` methods are available on all :class:`.Result` implementations including :class:`.ScalarResult`, :class:`.MappingResult`. For 2.0 only: To better support the use case of iterating :class:`.Result` and :class:`.AsyncResult` objects where user-defined exceptions may interrupt the iteration, both objects as well as variants such as :class:`.ScalarResult`, :class:`.MappingResult`, :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support context manager usage, where the result will be closed at the end of iteration. Corrected various typing issues within the engine and async engine packages. Fixes: #8710 Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio')
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py145
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py162
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py22
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py3
4 files changed, 266 insertions, 66 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 7fdd2d7e0..13d5e40b2 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -10,8 +10,13 @@ from __future__ import annotations
import abc
import functools
from typing import Any
+from typing import AsyncGenerator
+from typing import AsyncIterator
+from typing import Awaitable
+from typing import Callable
from typing import ClassVar
from typing import Dict
+from typing import Generator
from typing import Generic
from typing import NoReturn
from typing import Optional
@@ -25,6 +30,7 @@ from ... import util
from ...util.typing import Literal
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_PT = TypeVar("_PT", bound=Any)
@@ -114,27 +120,29 @@ class ReversibleProxy(Generic[_PT]):
SelfStartableContext = TypeVar(
- "SelfStartableContext", bound="StartableContext"
+ "SelfStartableContext", bound="StartableContext[Any]"
)
-class StartableContext(abc.ABC):
+class StartableContext(Awaitable[_T_co], abc.ABC):
__slots__ = ()
@abc.abstractmethod
async def start(
self: SelfStartableContext, is_ctxmanager: bool = False
- ) -> Any:
+ ) -> _T_co:
raise NotImplementedError()
- def __await__(self) -> Any:
+ def __await__(self) -> Generator[Any, Any, _T_co]:
return self.start().__await__()
- async def __aenter__(self: SelfStartableContext) -> Any:
- return await self.start(is_ctxmanager=True)
+ async def __aenter__(self: SelfStartableContext) -> _T_co:
+ return await self.start(is_ctxmanager=True) # type: ignore
@abc.abstractmethod
- async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+ async def __aexit__(
+ self, type_: Any, value: Any, traceback: Any
+ ) -> Optional[bool]:
pass
def _raise_for_not_started(self) -> NoReturn:
@@ -144,6 +152,129 @@ class StartableContext(abc.ABC):
)
+class GeneratorStartableContext(StartableContext[_T_co]):
+ __slots__ = ("gen",)
+
+ gen: AsyncGenerator[_T_co, Any]
+
+ def __init__(
+ self,
+ func: Callable[..., AsyncIterator[_T_co]],
+ args: tuple[Any, ...],
+ kwds: dict[str, Any],
+ ):
+ self.gen = func(*args, **kwds) # type: ignore
+
+ async def start(self, is_ctxmanager: bool = False) -> _T_co:
+ try:
+ start_value = await util.anext_(self.gen)
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ # if not a context manager, then interrupt the generator, don't
+ # let it complete. this step is technically not needed, as the
+ # generator will close in any case at gc time. not clear if having
+ # this here is a good idea or not (though it helps for clarity IMO)
+ if not is_ctxmanager:
+ await self.gen.aclose()
+
+ return start_value
+
+ async def __aexit__(
+ self, typ: Any, value: Any, traceback: Any
+ ) -> Optional[bool]:
+ # vendored from contextlib.py
+ if typ is None:
+ try:
+ await util.anext_(self.gen)
+ except StopAsyncIteration:
+ return False
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ # Need to force instantiation so we can reliably
+ # tell if we get the same exception back
+ value = typ()
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ except StopAsyncIteration as exc:
+ # Suppress StopIteration *unless* it's the same exception that
+ # was passed to throw(). This prevents a StopIteration
+ # raised inside the "with" statement from being suppressed.
+ return exc is not value
+ except RuntimeError as exc:
+ # Don't re-raise the passed in exception. (issue27122)
+ if exc is value:
+ return False
+ # Avoid suppressing if a Stop(Async)Iteration exception
+ # was passed to athrow() and later wrapped into a RuntimeError
+ # (see PEP 479 for sync generators; async generators also
+ # have this behavior). But do this only if the exception
+ # wrapped
+ # by the RuntimeError is actully Stop(Async)Iteration (see
+ # issue29692).
+ if (
+ isinstance(value, (StopIteration, StopAsyncIteration))
+ and exc.__cause__ is value
+ ):
+ return False
+ raise
+ except BaseException as exc:
+ # only re-raise if it's *not* the exception that was
+ # passed to throw(), because __exit__() must not raise
+ # an exception unless __exit__() itself failed. But throw()
+ # has to raise the exception to signal propagation, so this
+ # fixes the impedance mismatch between the throw() protocol
+ # and the __exit__() protocol.
+ if exc is not value:
+ raise
+ return False
+ raise RuntimeError("generator didn't stop after athrow()")
+
+
+def asyncstartablecontext(
+ func: Callable[..., AsyncIterator[_T_co]]
+) -> Callable[..., GeneratorStartableContext[_T_co]]:
+ """@asyncstartablecontext decorator.
+
+ the decorated function can be called either as ``async with fn()``, **or**
+ ``await fn()``. This is decidedly different from what
+ ``@contextlib.asynccontextmanager`` supports, and the usage pattern
+ is different as well.
+
+ Typical usage::
+
+ @asyncstartablecontext
+ async def some_async_generator(<arguments>):
+ <setup>
+ try:
+ yield <value>
+ except GeneratorExit:
+ # return value was awaited, no context manager is present
+ # and caller will .close() the resource explicitly
+ pass
+ else:
+ <context manager cleanup>
+
+
+ Above, ``GeneratorExit`` is caught if the function were used as an
+ ``await``. In this case, it's essential that the cleanup does **not**
+ occur, so there should not be a ``finally`` block.
+
+ If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
+ and we were invoked as a context manager, and cleanup should proceed.
+
+
+ """
+
+ @functools.wraps(func)
+ def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
+ return GeneratorStartableContext(func, args, kwds)
+
+ return helper
+
+
class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 389088875..854039c98 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -7,7 +7,9 @@
from __future__ import annotations
import asyncio
+import contextlib
from typing import Any
+from typing import AsyncIterator
from typing import Callable
from typing import Dict
from typing import Generator
@@ -21,6 +23,8 @@ from typing import TypeVar
from typing import Union
from . import exc as async_exc
+from .base import asyncstartablecontext
+from .base import GeneratorStartableContext
from .base import ProxyComparable
from .base import StartableContext
from .result import _ensure_sync_result
@@ -133,7 +137,9 @@ class AsyncConnectable:
],
)
class AsyncConnection(
- ProxyComparable[Connection], StartableContext, AsyncConnectable
+ ProxyComparable[Connection],
+ StartableContext["AsyncConnection"],
+ AsyncConnectable,
):
"""An asyncio proxy for a :class:`_engine.Connection`.
@@ -446,34 +452,67 @@ class AsyncConnection(
return await _ensure_sync_result(result, self.exec_driver_sql)
@overload
- async def stream(
+ def stream(
self,
statement: TypedReturnsRows[_T],
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncResult[_T]:
+ ) -> GeneratorStartableContext[AsyncResult[_T]]:
...
@overload
- async def stream(
+ def stream(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncResult[Any]:
+ ) -> GeneratorStartableContext[AsyncResult[Any]]:
...
+ @asyncstartablecontext
async def stream(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncResult[Any]:
- """Execute a statement and return a streaming
- :class:`_asyncio.AsyncResult` object."""
+ ) -> AsyncIterator[AsyncResult[Any]]:
+ """Execute a statement and return an awaitable yielding a
+ :class:`_asyncio.AsyncResult` object.
+
+ E.g.::
+
+ result = await conn.stream(stmt):
+ async for row in result:
+ print(f"{row}")
+
+ The :meth:`.AsyncConnection.stream`
+ method supports optional context manager use against the
+ :class:`.AsyncResult` object, as in::
+
+ async with conn.stream(stmt) as result:
+ async for row in result:
+ print(f"{row}")
+
+ In the above pattern, the :meth:`.AsyncResult.close` method is
+ invoked unconditionally, even if the iterator is interrupted by an
+ exception throw. Context manager use remains optional, however,
+ and the function may be called in either an ``async with fn():`` or
+ ``await fn()`` style.
+
+ .. versionadded:: 2.0.0b3 added context manager support
+
+
+ :return: an awaitable object that will yield an
+ :class:`_asyncio.AsyncResult` object.
+
+ .. seealso::
+
+ :meth:`.AsyncConnection.stream_scalars`
+
+ """
result = await greenlet_spawn(
self._proxied.execute,
@@ -484,10 +523,15 @@ class AsyncConnection(
),
_require_await=True,
)
- if not result.context._is_server_side:
- # TODO: real exception here
- assert False, "server side result expected"
- return AsyncResult(result)
+ assert result.context._is_server_side
+ ar = AsyncResult(result)
+ try:
+ yield ar
+ except GeneratorExit:
+ pass
+ else:
+ task = asyncio.create_task(ar.close())
+ await asyncio.shield(task)
@overload
async def execute(
@@ -642,48 +686,77 @@ class AsyncConnection(
return result.scalars()
@overload
- async def stream_scalars(
+ def stream_scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncScalarResult[_T]:
+ ) -> GeneratorStartableContext[AsyncScalarResult[_T]]:
...
@overload
- async def stream_scalars(
+ def stream_scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncScalarResult[Any]:
+ ) -> GeneratorStartableContext[AsyncScalarResult[Any]]:
...
+ @asyncstartablecontext
async def stream_scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
- ) -> AsyncScalarResult[Any]:
- r"""Executes a SQL statement and returns a streaming scalar result
- object.
+ ) -> AsyncIterator[AsyncScalarResult[Any]]:
+ r"""Execute a statement and return an awaitable yielding a
+ :class:`_asyncio.AsyncScalarResult` object.
+
+ E.g.::
+
+ result = await conn.stream_scalars(stmt):
+ async for scalar in result:
+ print(f"{scalar}")
This method is shorthand for invoking the
:meth:`_engine.AsyncResult.scalars` method after invoking the
:meth:`_engine.Connection.stream` method. Parameters are equivalent.
- :return: an :class:`_asyncio.AsyncScalarResult` object.
+ The :meth:`.AsyncConnection.stream_scalars`
+ method supports optional context manager use against the
+ :class:`.AsyncScalarResult` object, as in::
+
+ async with conn.stream_scalars(stmt) as result:
+ async for scalar in result:
+ print(f"{scalar}")
+
+ In the above pattern, the :meth:`.AsyncScalarResult.close` method is
+ invoked unconditionally, even if the iterator is interrupted by an
+ exception throw. Context manager use remains optional, however,
+ and the function may be called in either an ``async with fn():`` or
+ ``await fn()`` style.
+
+ .. versionadded:: 2.0.0b3 added context manager support
+
+ :return: an awaitable object that will yield an
+ :class:`_asyncio.AsyncScalarResult` object.
.. versionadded:: 1.4.24
+ .. seealso::
+
+ :meth:`.AsyncConnection.stream`
+
"""
- result = await self.stream(
+
+ async with self.stream(
statement, parameters, execution_options=execution_options
- )
- return result.scalars()
+ ) as result:
+ yield result.scalars()
async def run_sync(
self, fn: Callable[..., Any], *arg: Any, **kw: Any
@@ -856,37 +929,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
:ref:`asyncio_events`
"""
- class _trans_ctx(StartableContext):
- __slots__ = ("conn", "transaction")
-
- conn: AsyncConnection
- transaction: AsyncTransaction
-
- def __init__(self, conn: AsyncConnection):
- self.conn = conn
-
- if TYPE_CHECKING:
-
- async def __aenter__(self) -> AsyncConnection:
- ...
-
- async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
- await self.conn.start(is_ctxmanager=is_ctxmanager)
- self.transaction = self.conn.begin()
- await self.transaction.__aenter__()
-
- return self.conn
-
- async def __aexit__(
- self, type_: Any, value: Any, traceback: Any
- ) -> None:
- async def go() -> None:
- await self.transaction.__aexit__(type_, value, traceback)
- await self.conn.close()
-
- task = asyncio.create_task(go())
- await asyncio.shield(task)
-
def __init__(self, sync_engine: Engine):
if not sync_engine.dialect.is_async:
raise exc.InvalidRequestError(
@@ -903,7 +945,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
return AsyncEngine(target)
- def begin(self) -> AsyncEngine._trans_ctx:
+ @contextlib.asynccontextmanager
+ async def begin(self) -> AsyncIterator[AsyncConnection]:
"""Return a context manager which when entered will deliver an
:class:`_asyncio.AsyncConnection` with an
:class:`_asyncio.AsyncTransaction` established.
@@ -919,7 +962,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
"""
conn = self.connect()
- return self._trans_ctx(conn)
+
+ async with conn:
+ async with conn.begin():
+ yield conn
def connect(self) -> AsyncConnection:
"""Return an :class:`_asyncio.AsyncConnection` object.
@@ -1185,7 +1231,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
# END PROXY METHODS AsyncEngine
-class AsyncTransaction(ProxyComparable[Transaction], StartableContext):
+class AsyncTransaction(
+ ProxyComparable[Transaction], StartableContext["AsyncTransaction"]
+):
"""An asyncio proxy for a :class:`_engine.Transaction`."""
__slots__ = ("connection", "sync_transaction", "nested")
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 8f1c07fd8..41ead5ee2 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -47,11 +47,21 @@ class AsyncCommon(FilterResult[_R]):
_real_result: Result[Any]
_metadata: ResultMetaData
- async def close(self) -> None:
+ async def close(self) -> None: # type: ignore[override]
"""Close this result."""
await greenlet_spawn(self._real_result.close)
+ @property
+ def closed(self) -> bool:
+ """proxies the .closed attribute of the underlying result object,
+ if any, else raises ``AttributeError``.
+
+ .. versionadded:: 2.0.0b3
+
+ """
+ return self._real_result.closed # type: ignore
+
SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]")
@@ -96,6 +106,16 @@ class AsyncResult(AsyncCommon[Row[_TP]]):
)
@property
+ def closed(self) -> bool:
+ """proxies the .closed attribute of the underlying result object,
+ if any, else raises ``AttributeError``.
+
+ .. versionadded:: 2.0.0b3
+
+ """
+ return self._real_result.closed # type: ignore
+
+ @property
def t(self) -> AsyncTupleResult[_TP]:
"""Apply a "typed tuple" typing filter to returned rows.
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 0aa9661e9..1956ea588 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -1547,7 +1547,8 @@ class _AsyncSessionContextManager(Generic[_AS]):
class AsyncSessionTransaction(
- ReversibleProxy[SessionTransaction], StartableContext
+ ReversibleProxy[SessionTransaction],
+ StartableContext["AsyncSessionTransaction"],
):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.