diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/engine/result.py | 59 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/base.py | 145 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 162 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/result.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 23 |
8 files changed, 355 insertions, 69 deletions
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 05ca17063..bcd2f0ea9 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -928,6 +928,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): def __init__(self, cursor_metadata: ResultMetaData): self._metadata = cursor_metadata + def __enter__(self) -> Result[_TP]: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + def close(self) -> None: """close this :class:`_result.Result`. @@ -950,6 +956,19 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): """ self._soft_close(hard=True) + @property + def _soft_closed(self) -> bool: + raise NotImplementedError() + + @property + def closed(self) -> bool: + """return True if this :class:`.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + raise NotImplementedError() + @_generative def yield_per(self: SelfResult, num: int) -> SelfResult: """Configure the row-fetching strategy to fetch ``num`` rows at a time. @@ -1574,6 +1593,12 @@ class FilterResult(ResultInternal[_R]): _real_result: Result[Any] + def __enter__(self: SelfFilterResult) -> SelfFilterResult: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self._real_result.__exit__(type_, value, traceback) + @_generative def yield_per(self: SelfFilterResult, num: int) -> SelfFilterResult: """Configure the row-fetching strategy to fetch ``num`` rows at a time. @@ -1600,6 +1625,27 @@ class FilterResult(ResultInternal[_R]): self._real_result._soft_close(hard=hard) @property + def _soft_closed(self) -> bool: + return self._real_result._soft_closed + + @property + def closed(self) -> bool: + """return True if the underlying :class:`.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + return self._real_result.closed + + def close(self) -> None: + """Close this :class:`.FilterResult`. + + .. versionadded:: 1.4.43 + + """ + self._real_result.close() + + @property def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes @@ -2172,7 +2218,7 @@ class IteratorResult(Result[_TP]): self, cursor_metadata: ResultMetaData, iterator: Iterator[_InterimSupportsScalarsRowType], - raw: Optional[Any] = None, + raw: Optional[Result[Any]] = None, _source_supports_scalars: bool = False, ): self._metadata = cursor_metadata @@ -2180,6 +2226,15 @@ class IteratorResult(Result[_TP]): self.raw = raw self._source_supports_scalars = _source_supports_scalars + @property + def closed(self) -> bool: + """return True if this :class:`.IteratorResult` has been closed + + .. versionadded:: 1.4.43 + + """ + return self._hard_closed + def _soft_close(self, hard: bool = False, **kw: Any) -> None: if hard: self._hard_closed = True @@ -2262,7 +2317,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]): [Optional[int]], Iterator[Sequence[_InterimRowType[_R]]] ], source_supports_scalars: bool = False, - raw: Optional[Any] = None, + raw: Optional[Result[Any]] = None, dynamic_yield_per: bool = False, ): self._metadata = cursor_metadata diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 7fdd2d7e0..13d5e40b2 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -10,8 +10,13 @@ from __future__ import annotations import abc import functools from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable from typing import ClassVar from typing import Dict +from typing import Generator from typing import Generic from typing import NoReturn from typing import Optional @@ -25,6 +30,7 @@ from ... import util from ...util.typing import Literal _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _PT = TypeVar("_PT", bound=Any) @@ -114,27 +120,29 @@ class ReversibleProxy(Generic[_PT]): SelfStartableContext = TypeVar( - "SelfStartableContext", bound="StartableContext" + "SelfStartableContext", bound="StartableContext[Any]" ) -class StartableContext(abc.ABC): +class StartableContext(Awaitable[_T_co], abc.ABC): __slots__ = () @abc.abstractmethod async def start( self: SelfStartableContext, is_ctxmanager: bool = False - ) -> Any: + ) -> _T_co: raise NotImplementedError() - def __await__(self) -> Any: + def __await__(self) -> Generator[Any, Any, _T_co]: return self.start().__await__() - async def __aenter__(self: SelfStartableContext) -> Any: - return await self.start(is_ctxmanager=True) + async def __aenter__(self: SelfStartableContext) -> _T_co: + return await self.start(is_ctxmanager=True) # type: ignore @abc.abstractmethod - async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> Optional[bool]: pass def _raise_for_not_started(self) -> NoReturn: @@ -144,6 +152,129 @@ class StartableContext(abc.ABC): ) +class GeneratorStartableContext(StartableContext[_T_co]): + __slots__ = ("gen",) + + gen: AsyncGenerator[_T_co, Any] + + def __init__( + self, + func: Callable[..., AsyncIterator[_T_co]], + args: tuple[Any, ...], + kwds: dict[str, Any], + ): + self.gen = func(*args, **kwds) # type: ignore + + async def start(self, is_ctxmanager: bool = False) -> _T_co: + try: + start_value = await util.anext_(self.gen) + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + # if not a context manager, then interrupt the generator, don't + # let it complete. this step is technically not needed, as the + # generator will close in any case at gc time. not clear if having + # this here is a good idea or not (though it helps for clarity IMO) + if not is_ctxmanager: + await self.gen.aclose() + + return start_value + + async def __aexit__( + self, typ: Any, value: Any, traceback: Any + ) -> Optional[bool]: + # vendored from contextlib.py + if typ is None: + try: + await util.anext_(self.gen) + except StopAsyncIteration: + return False + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = typ() + try: + await self.gen.athrow(typ, value, traceback) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27122) + if exc is value: + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception + # wrapped + # by the RuntimeError is actully Stop(Async)Iteration (see + # issue29692). + if ( + isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not value: + raise + return False + raise RuntimeError("generator didn't stop after athrow()") + + +def asyncstartablecontext( + func: Callable[..., AsyncIterator[_T_co]] +) -> Callable[..., GeneratorStartableContext[_T_co]]: + """@asyncstartablecontext decorator. + + the decorated function can be called either as ``async with fn()``, **or** + ``await fn()``. This is decidedly different from what + ``@contextlib.asynccontextmanager`` supports, and the usage pattern + is different as well. + + Typical usage:: + + @asyncstartablecontext + async def some_async_generator(<arguments>): + <setup> + try: + yield <value> + except GeneratorExit: + # return value was awaited, no context manager is present + # and caller will .close() the resource explicitly + pass + else: + <context manager cleanup> + + + Above, ``GeneratorExit`` is caught if the function were used as an + ``await``. In this case, it's essential that the cleanup does **not** + occur, so there should not be a ``finally`` block. + + If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__`` + and we were invoked as a context manager, and cleanup should proceed. + + + """ + + @functools.wraps(func) + def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]: + return GeneratorStartableContext(func, args, kwds) + + return helper + + class ProxyComparable(ReversibleProxy[_PT]): __slots__ = () diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 389088875..854039c98 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,7 +7,9 @@ from __future__ import annotations import asyncio +import contextlib from typing import Any +from typing import AsyncIterator from typing import Callable from typing import Dict from typing import Generator @@ -21,6 +23,8 @@ from typing import TypeVar from typing import Union from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext from .base import ProxyComparable from .base import StartableContext from .result import _ensure_sync_result @@ -133,7 +137,9 @@ class AsyncConnectable: ], ) class AsyncConnection( - ProxyComparable[Connection], StartableContext, AsyncConnectable + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, ): """An asyncio proxy for a :class:`_engine.Connection`. @@ -446,34 +452,67 @@ class AsyncConnection( return await _ensure_sync_result(result, self.exec_driver_sql) @overload - async def stream( + def stream( self, statement: TypedReturnsRows[_T], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[_T]: + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... @overload - async def stream( + def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + @asyncstartablecontext async def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncResult[Any]: - """Execute a statement and return a streaming - :class:`_asyncio.AsyncResult` object.""" + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt): + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ result = await greenlet_spawn( self._proxied.execute, @@ -484,10 +523,15 @@ class AsyncConnection( ), _require_await=True, ) - if not result.context._is_server_side: - # TODO: real exception here - assert False, "server side result expected" - return AsyncResult(result) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) @overload async def execute( @@ -642,48 +686,77 @@ class AsyncConnection( return result.scalars() @overload - async def stream_scalars( + def stream_scalars( self, statement: TypedReturnsRows[Tuple[_T]], parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[_T]: + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... @overload - async def stream_scalars( + def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + @asyncstartablecontext async def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> AsyncScalarResult[Any]: - r"""Executes a SQL statement and returns a streaming scalar result - object. + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt): + async for scalar in result: + print(f"{scalar}") This method is shorthand for invoking the :meth:`_engine.AsyncResult.scalars` method after invoking the :meth:`_engine.Connection.stream` method. Parameters are equivalent. - :return: an :class:`_asyncio.AsyncScalarResult` object. + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. .. versionadded:: 1.4.24 + .. seealso:: + + :meth:`.AsyncConnection.stream` + """ - result = await self.stream( + + async with self.stream( statement, parameters, execution_options=execution_options - ) - return result.scalars() + ) as result: + yield result.scalars() async def run_sync( self, fn: Callable[..., Any], *arg: Any, **kw: Any @@ -856,37 +929,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): :ref:`asyncio_events` """ - class _trans_ctx(StartableContext): - __slots__ = ("conn", "transaction") - - conn: AsyncConnection - transaction: AsyncTransaction - - def __init__(self, conn: AsyncConnection): - self.conn = conn - - if TYPE_CHECKING: - - async def __aenter__(self) -> AsyncConnection: - ... - - async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: - await self.conn.start(is_ctxmanager=is_ctxmanager) - self.transaction = self.conn.begin() - await self.transaction.__aenter__() - - return self.conn - - async def __aexit__( - self, type_: Any, value: Any, traceback: Any - ) -> None: - async def go() -> None: - await self.transaction.__aexit__(type_, value, traceback) - await self.conn.close() - - task = asyncio.create_task(go()) - await asyncio.shield(task) - def __init__(self, sync_engine: Engine): if not sync_engine.dialect.is_async: raise exc.InvalidRequestError( @@ -903,7 +945,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) - def begin(self) -> AsyncEngine._trans_ctx: + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: """Return a context manager which when entered will deliver an :class:`_asyncio.AsyncConnection` with an :class:`_asyncio.AsyncTransaction` established. @@ -919,7 +962,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): """ conn = self.connect() - return self._trans_ctx(conn) + + async with conn: + async with conn.begin(): + yield conn def connect(self) -> AsyncConnection: """Return an :class:`_asyncio.AsyncConnection` object. @@ -1185,7 +1231,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): # END PROXY METHODS AsyncEngine -class AsyncTransaction(ProxyComparable[Transaction], StartableContext): +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 8f1c07fd8..41ead5ee2 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -47,11 +47,21 @@ class AsyncCommon(FilterResult[_R]): _real_result: Result[Any] _metadata: ResultMetaData - async def close(self) -> None: + async def close(self) -> None: # type: ignore[override] """Close this result.""" await greenlet_spawn(self._real_result.close) + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed # type: ignore + SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]") @@ -96,6 +106,16 @@ class AsyncResult(AsyncCommon[Row[_TP]]): ) @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed # type: ignore + + @property def t(self) -> AsyncTupleResult[_TP]: """Apply a "typed tuple" typing filter to returned rows. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 0aa9661e9..1956ea588 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1547,7 +1547,8 @@ class _AsyncSessionContextManager(Generic[_AS]): class AsyncSessionTransaction( - ReversibleProxy[SessionTransaction], StartableContext + ReversibleProxy[SessionTransaction], + StartableContext["AsyncSessionTransaction"], ): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 63035f585..9ac6d07da 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2713,7 +2713,14 @@ class Query( return None def __iter__(self) -> Iterable[_T]: - return self._iter().__iter__() # type: ignore + result = self._iter() + try: + yield from result + except GeneratorExit: + # issue #8710 - direct iteration is not re-usable after + # an iterable block is broken, so close the result + result._soft_close() + raise def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index cd7e0fd81..4952cb501 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -46,6 +46,7 @@ from ._collections import UniqueAppender as UniqueAppender from ._collections import update_copy as update_copy from ._collections import WeakPopulateDict as WeakPopulateDict from ._collections import WeakSequence as WeakSequence +from .compat import anext_ as anext_ from .compat import arm as arm from .compat import b as b from .compat import b64decode as b64decode diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 4ce1e7ff3..cda5ab6c1 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -111,6 +111,29 @@ else: return a +if py310: + anext_ = anext +else: + + _NOT_PROVIDED = object() + from collections.abc import AsyncIterator + + async def anext_(async_iterator, default=_NOT_PROVIDED): + """vendored from https://github.com/python/cpython/pull/8895""" + + if not isinstance(async_iterator, AsyncIterator): + raise TypeError( + f"anext expected an AsyncIterator, got {type(async_iterator)}" + ) + anxt = type(async_iterator).__anext__ + try: + return await anxt(async_iterator) + except StopAsyncIteration: + if default is _NOT_PROVIDED: + raise + return default + + def importlib_metadata_get(group): ep = importlib_metadata.entry_points() if not typing.TYPE_CHECKING and hasattr(ep, "select"): |