summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/base.py')
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py145
1 files changed, 138 insertions, 7 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 7fdd2d7e0..13d5e40b2 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -10,8 +10,13 @@ from __future__ import annotations
import abc
import functools
from typing import Any
+from typing import AsyncGenerator
+from typing import AsyncIterator
+from typing import Awaitable
+from typing import Callable
from typing import ClassVar
from typing import Dict
+from typing import Generator
from typing import Generic
from typing import NoReturn
from typing import Optional
@@ -25,6 +30,7 @@ from ... import util
from ...util.typing import Literal
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_PT = TypeVar("_PT", bound=Any)
@@ -114,27 +120,29 @@ class ReversibleProxy(Generic[_PT]):
SelfStartableContext = TypeVar(
- "SelfStartableContext", bound="StartableContext"
+ "SelfStartableContext", bound="StartableContext[Any]"
)
-class StartableContext(abc.ABC):
+class StartableContext(Awaitable[_T_co], abc.ABC):
__slots__ = ()
@abc.abstractmethod
async def start(
self: SelfStartableContext, is_ctxmanager: bool = False
- ) -> Any:
+ ) -> _T_co:
raise NotImplementedError()
- def __await__(self) -> Any:
+ def __await__(self) -> Generator[Any, Any, _T_co]:
return self.start().__await__()
- async def __aenter__(self: SelfStartableContext) -> Any:
- return await self.start(is_ctxmanager=True)
+ async def __aenter__(self: SelfStartableContext) -> _T_co:
+ return await self.start(is_ctxmanager=True) # type: ignore
@abc.abstractmethod
- async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+ async def __aexit__(
+ self, type_: Any, value: Any, traceback: Any
+ ) -> Optional[bool]:
pass
def _raise_for_not_started(self) -> NoReturn:
@@ -144,6 +152,129 @@ class StartableContext(abc.ABC):
)
+class GeneratorStartableContext(StartableContext[_T_co]):
+ __slots__ = ("gen",)
+
+ gen: AsyncGenerator[_T_co, Any]
+
+ def __init__(
+ self,
+ func: Callable[..., AsyncIterator[_T_co]],
+ args: tuple[Any, ...],
+ kwds: dict[str, Any],
+ ):
+ self.gen = func(*args, **kwds) # type: ignore
+
+ async def start(self, is_ctxmanager: bool = False) -> _T_co:
+ try:
+ start_value = await util.anext_(self.gen)
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ # if not a context manager, then interrupt the generator, don't
+ # let it complete. this step is technically not needed, as the
+ # generator will close in any case at gc time. not clear if having
+ # this here is a good idea or not (though it helps for clarity IMO)
+ if not is_ctxmanager:
+ await self.gen.aclose()
+
+ return start_value
+
+ async def __aexit__(
+ self, typ: Any, value: Any, traceback: Any
+ ) -> Optional[bool]:
+ # vendored from contextlib.py
+ if typ is None:
+ try:
+ await util.anext_(self.gen)
+ except StopAsyncIteration:
+ return False
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ # Need to force instantiation so we can reliably
+ # tell if we get the same exception back
+ value = typ()
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ except StopAsyncIteration as exc:
+ # Suppress StopIteration *unless* it's the same exception that
+ # was passed to throw(). This prevents a StopIteration
+ # raised inside the "with" statement from being suppressed.
+ return exc is not value
+ except RuntimeError as exc:
+ # Don't re-raise the passed in exception. (issue27122)
+ if exc is value:
+ return False
+ # Avoid suppressing if a Stop(Async)Iteration exception
+ # was passed to athrow() and later wrapped into a RuntimeError
+ # (see PEP 479 for sync generators; async generators also
+ # have this behavior). But do this only if the exception
+ # wrapped
+ # by the RuntimeError is actully Stop(Async)Iteration (see
+ # issue29692).
+ if (
+ isinstance(value, (StopIteration, StopAsyncIteration))
+ and exc.__cause__ is value
+ ):
+ return False
+ raise
+ except BaseException as exc:
+ # only re-raise if it's *not* the exception that was
+ # passed to throw(), because __exit__() must not raise
+ # an exception unless __exit__() itself failed. But throw()
+ # has to raise the exception to signal propagation, so this
+ # fixes the impedance mismatch between the throw() protocol
+ # and the __exit__() protocol.
+ if exc is not value:
+ raise
+ return False
+ raise RuntimeError("generator didn't stop after athrow()")
+
+
+def asyncstartablecontext(
+ func: Callable[..., AsyncIterator[_T_co]]
+) -> Callable[..., GeneratorStartableContext[_T_co]]:
+ """@asyncstartablecontext decorator.
+
+ the decorated function can be called either as ``async with fn()``, **or**
+ ``await fn()``. This is decidedly different from what
+ ``@contextlib.asynccontextmanager`` supports, and the usage pattern
+ is different as well.
+
+ Typical usage::
+
+ @asyncstartablecontext
+ async def some_async_generator(<arguments>):
+ <setup>
+ try:
+ yield <value>
+ except GeneratorExit:
+ # return value was awaited, no context manager is present
+ # and caller will .close() the resource explicitly
+ pass
+ else:
+ <context manager cleanup>
+
+
+ Above, ``GeneratorExit`` is caught if the function were used as an
+ ``await``. In this case, it's essential that the cleanup does **not**
+ occur, so there should not be a ``finally`` block.
+
+ If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
+ and we were invoked as a context manager, and cleanup should proceed.
+
+
+ """
+
+ @functools.wraps(func)
+ def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
+ return GeneratorStartableContext(func, args, kwds)
+
+ return helper
+
+
class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()