diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-10 15:42:35 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-11 22:11:07 -0400 |
commit | a45e2284dad17fbbba3bea9d5e5304aab21c8c94 (patch) | |
tree | ac31614f2d53059570e2edffe731baf384baea23 /lib/sqlalchemy/ext/asyncio | |
parent | aa9cd878e8249a4a758c7f968e929e92fede42a5 (diff) | |
download | sqlalchemy-a45e2284dad17fbbba3bea9d5e5304aab21c8c94.tar.gz |
pep-484: asyncio
in this patch the asyncio/events.py module, which
existed only to raise errors when trying to attach event
listeners, is removed, as we were already coding an asyncio-specific
workaround in upstream Pool / Session to raise this error,
just moved the error out to the target and did the same thing
for Engine.
We also add an async_sessionmaker class. The initial rationale
here is because sessionmaker() is hardcoded to Session subclasses,
and there's not a way to get the use case of
sessionmaker(class_=AsyncSession) to type correctly without changing
the sessionmaker() symbol itself to be a function and not a class,
which gets too complicated for what this is. Additionally,
_SessionClassMethods has only three methods on it, one of which
is not usable with asyncio (close_all()), the others
not generally used from the session class.
Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/__init__.py | 29 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/base.py | 118 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 401 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/events.py | 44 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/result.py | 180 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/scoping.py | 241 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 406 |
7 files changed, 956 insertions, 463 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 15b2cb015..dfe89a154 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -5,18 +5,17 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from .engine import async_engine_from_config -from .engine import AsyncConnection -from .engine import AsyncEngine -from .engine import AsyncTransaction -from .engine import create_async_engine -from .events import AsyncConnectionEvents -from .events import AsyncSessionEvents -from .result import AsyncMappingResult -from .result import AsyncResult -from .result import AsyncScalarResult -from .scoping import async_scoped_session -from .session import async_object_session -from .session import async_session -from .session import AsyncSession -from .session import AsyncSessionTransaction +from .engine import async_engine_from_config as async_engine_from_config +from .engine import AsyncConnection as AsyncConnection +from .engine import AsyncEngine as AsyncEngine +from .engine import AsyncTransaction as AsyncTransaction +from .engine import create_async_engine as create_async_engine +from .result import AsyncMappingResult as AsyncMappingResult +from .result import AsyncResult as AsyncResult +from .result import AsyncScalarResult as AsyncScalarResult +from .scoping import async_scoped_session as async_scoped_session +from .session import async_object_session as async_object_session +from .session import async_session as async_session +from .session import async_sessionmaker as async_sessionmaker +from .session import AsyncSession as AsyncSession +from .session import AsyncSessionTransaction as AsyncSessionTransaction diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 3f77f5500..7fdd2d7e0 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,36 +1,103 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + import abc import functools +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TypeVar import weakref from . import exc as async_exc +from ... import util +from ...util.typing import Literal + +_T = TypeVar("_T", bound=Any) + + +_PT = TypeVar("_PT", bound=Any) -class ReversibleProxy: - # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) - _proxy_objects = {} +SelfReversibleProxy = TypeVar( + "SelfReversibleProxy", bound="ReversibleProxy[Any]" +) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} __slots__ = ("__weakref__",) - def _assign_proxied(self, target): + @overload + def _assign_proxied(self, target: _PT) -> _PT: + ... + + @overload + def _assign_proxied(self, target: None) -> None: + ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: - target_ref = weakref.ref(target, ReversibleProxy._target_gced) + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) proxy_ref = weakref.ref( self, - functools.partial(ReversibleProxy._target_gced, target_ref), + functools.partial( # type: ignore + ReversibleProxy._target_gced, target_ref + ), ) ReversibleProxy._proxy_objects[target_ref] = proxy_ref return target @classmethod - def _target_gced(cls, ref, proxy_ref=None): + def _target_gced( + cls: Type[SelfReversibleProxy], + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None, + ) -> None: cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT + ) -> SelfReversibleProxy: raise NotImplementedError() + @overload @classmethod - def _retrieve_proxy_for_target(cls, target, regenerate=True): + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], + target: _PT, + regenerate: Literal[True] = ..., + ) -> SelfReversibleProxy: + ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: + ... + + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] except KeyError: @@ -38,7 +105,7 @@ class ReversibleProxy: else: proxy = proxy_ref() if proxy is not None: - return proxy + return proxy # type: ignore if regenerate: return cls._regenerate_proxy_for_target(target) @@ -46,43 +113,54 @@ class ReversibleProxy: return None +SelfStartableContext = TypeVar( + "SelfStartableContext", bound="StartableContext" +) + + class StartableContext(abc.ABC): __slots__ = () @abc.abstractmethod - async def start(self, is_ctxmanager=False): - pass + async def start( + self: SelfStartableContext, is_ctxmanager: bool = False + ) -> Any: + raise NotImplementedError() - def __await__(self): + def __await__(self) -> Any: return self.start().__await__() - async def __aenter__(self): + async def __aenter__(self: SelfStartableContext) -> Any: return await self.start(is_ctxmanager=True) @abc.abstractmethod - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: pass - def _raise_for_not_started(self): + def _raise_for_not_started(self) -> NoReturn: raise async_exc.AsyncContextNotStarted( "%s context has not been started and object has not been awaited." % (self.__class__.__name__) ) -class ProxyComparable(ReversibleProxy): +class ProxyComparable(ReversibleProxy[_PT]): __slots__ = () - def __hash__(self): + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, self.__class__) and self._proxied == other._proxied ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return ( not isinstance(other, self.__class__) or self._proxied != other._proxied diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 3b54405c1..bb51a4d22 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,23 +7,56 @@ from __future__ import annotations from typing import Any +from typing import Dict +from typing import Generator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import exc as async_exc from .base import ProxyComparable from .base import StartableContext from .result import _ensure_sync_result from .result import AsyncResult +from .result import AsyncScalarResult from ... import exc from ... import inspection from ... import util +from ...engine import Connection from ...engine import create_engine as _create_engine +from ...engine import Engine from ...engine.base import NestedTransaction -from ...future import Connection -from ...future import Engine +from ...engine.base import Transaction from ...util.concurrency import greenlet_spawn - - -def create_async_engine(*arg, **kw): +from ...util.typing import Protocol + +if TYPE_CHECKING: + from ...engine import Connection + from ...engine import Engine + from ...engine.cursor import CursorResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _DBAPIAnyExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...engine.interfaces import _IsolationLevel + from ...engine.interfaces import Dialect + from ...engine.result import ScalarResult + from ...engine.url import URL + from ...pool import Pool + from ...pool import PoolProxiedConnection + from ...sql.base import Executable + + +class _SyncConnectionCallable(Protocol): + def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any: + ... + + +def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: """Create a new async engine instance. Arguments passed to :func:`_asyncio.create_async_engine` are mostly @@ -43,11 +76,13 @@ def create_async_engine(*arg, **kw): ) kw["future"] = True kw["_is_async"] = True - sync_engine = _create_engine(*arg, **kw) + sync_engine = _create_engine(url, **kw) return AsyncEngine(sync_engine) -def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): +def async_engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> AsyncEngine: """Create a new AsyncEngine instance using a configuration dictionary. This function is analogous to the :func:`_sa.engine_from_config` function @@ -73,6 +108,14 @@ def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): class AsyncConnectable: __slots__ = "_slots_dispatch", "__weakref__" + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + @util.create_proxy_methods( Connection, @@ -87,7 +130,9 @@ class AsyncConnectable: "default_isolation_level", ], ) -class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): +class AsyncConnection( + ProxyComparable[Connection], StartableContext, AsyncConnectable +): """An asyncio proxy for a :class:`_engine.Connection`. :class:`_asyncio.AsyncConnection` is acquired using the @@ -115,12 +160,16 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): "sync_connection", ) - def __init__(self, async_engine, sync_connection=None): + def __init__( + self, + async_engine: AsyncEngine, + sync_connection: Optional[Connection] = None, + ): self.engine = async_engine self.sync_engine = async_engine.sync_engine self.sync_connection = self._assign_proxied(sync_connection) - sync_connection: Connection + sync_connection: Optional[Connection] """Reference to the sync-style :class:`_engine.Connection` this :class:`_asyncio.AsyncConnection` proxies requests towards. @@ -146,12 +195,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls, target: Connection + ) -> AsyncConnection: return AsyncConnection( AsyncEngine._retrieve_proxy_for_target(target.engine), target ) - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: """Start this :class:`_asyncio.AsyncConnection` object's context outside of using a Python ``with:`` block. @@ -164,7 +215,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): return self @property - def connection(self): + def connection(self) -> NoReturn: """Not implemented for async; call :meth:`_asyncio.AsyncConnection.get_raw_connection`. """ @@ -174,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): "Use the get_raw_connection() method." ) - async def get_raw_connection(self): + async def get_raw_connection(self) -> PoolProxiedConnection: """Return the pooled DBAPI-level connection in use by this :class:`_asyncio.AsyncConnection`. @@ -187,16 +238,11 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): adapts the driver connection to the DBAPI protocol. """ - conn = self._sync_connection() - - return await greenlet_spawn(getattr, conn, "connection") - @property - def _proxied(self): - return self.sync_connection + return await greenlet_spawn(getattr, self._proxied, "connection") @property - def info(self): + def info(self) -> Dict[str, Any]: """Return the :attr:`_engine.Connection.info` dictionary of the underlying :class:`_engine.Connection`. @@ -211,24 +257,28 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - return self.sync_connection.info + return self._proxied.info - def _sync_connection(self): + @util.ro_non_memoized_property + def _proxied(self) -> Connection: if not self.sync_connection: self._raise_for_not_started() return self.sync_connection - def begin(self): + def begin(self) -> AsyncTransaction: """Begin a transaction prior to autobegin occurring.""" - self._sync_connection() + assert self._proxied return AsyncTransaction(self) - def begin_nested(self): + def begin_nested(self) -> AsyncTransaction: """Begin a nested transaction and return a transaction handle.""" - self._sync_connection() + assert self._proxied return AsyncTransaction(self, nested=True) - async def invalidate(self, exception=None): + async def invalidate( + self, exception: Optional[BaseException] = None + ) -> None: + """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. @@ -237,39 +287,27 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() - return await greenlet_spawn(conn.invalidate, exception=exception) - - async def get_isolation_level(self): - conn = self._sync_connection() - return await greenlet_spawn(conn.get_isolation_level) - - async def set_isolation_level(self): - conn = self._sync_connection() - return await greenlet_spawn(conn.get_isolation_level) - - def in_transaction(self): - """Return True if a transaction is in progress. - - .. versionadded:: 1.4.0b2 + return await greenlet_spawn( + self._proxied.invalidate, exception=exception + ) - """ + async def get_isolation_level(self) -> _IsolationLevel: + return await greenlet_spawn(self._proxied.get_isolation_level) - conn = self._sync_connection() + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" - return conn.in_transaction() + return self._proxied.in_transaction() - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if a transaction is in progress. .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - - return conn.in_nested_transaction() + return self._proxied.in_nested_transaction() - def get_transaction(self): + def get_transaction(self) -> Optional[AsyncTransaction]: """Return an :class:`.AsyncTransaction` representing the current transaction, if any. @@ -281,15 +319,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - trans = conn.get_transaction() + trans = self._proxied.get_transaction() if trans is not None: return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[AsyncTransaction]: """Return an :class:`.AsyncTransaction` representing the current nested (savepoint) transaction, if any. @@ -301,15 +338,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - trans = conn.get_nested_transaction() + trans = self._proxied.get_nested_transaction() if trans is not None: return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None - async def execution_options(self, **opt): + async def execution_options(self, **opt: Any) -> AsyncConnection: r"""Set non-SQL options for the connection which take effect during execution. @@ -321,12 +357,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() + conn = self._proxied c2 = await greenlet_spawn(conn.execution_options, **opt) assert c2 is conn return self - async def commit(self): + async def commit(self) -> None: """Commit the transaction that is currently in progress. This method commits the current transaction if one has been started. @@ -338,10 +374,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :meth:`_future.Connection.begin` method is called. """ - conn = self._sync_connection() - await greenlet_spawn(conn.commit) + await greenlet_spawn(self._proxied.commit) - async def rollback(self): + async def rollback(self) -> None: """Roll back the transaction that is currently in progress. This method rolls back the current transaction if one has been started. @@ -355,34 +390,30 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() - await greenlet_spawn(conn.rollback) + await greenlet_spawn(self._proxied.rollback) - async def close(self): + async def close(self) -> None: """Close this :class:`_asyncio.AsyncConnection`. This has the effect of also rolling back the transaction if one is in place. """ - conn = self._sync_connection() - await greenlet_spawn(conn.close) + await greenlet_spawn(self._proxied.close) async def exec_driver_sql( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult: r"""Executes a driver-level SQL string and return buffered :class:`_engine.Result`. """ - conn = self._sync_connection() - result = await greenlet_spawn( - conn.exec_driver_sql, + self._proxied.exec_driver_sql, statement, parameters, execution_options, @@ -393,17 +424,15 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def stream( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object.""" - conn = self._sync_connection() - result = await greenlet_spawn( - conn.execute, + self._proxied.execute, statement, parameters, util.EMPTY_DICT.merge_with( @@ -418,10 +447,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def execute( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult: r"""Executes a SQL statement construct and return a buffered :class:`_engine.Result`. @@ -453,10 +482,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :return: a :class:`_engine.Result` object. """ - conn = self._sync_connection() - result = await greenlet_spawn( - conn.execute, + self._proxied.execute, statement, parameters, execution_options, @@ -466,10 +493,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def scalar( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. This method is shorthand for invoking the @@ -485,10 +512,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def scalars( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: r"""Executes a SQL statement construct and returns a scalar objects. This method is shorthand for invoking the @@ -505,10 +532,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def stream_scalars( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[Any]: r"""Executes a SQL statement and returns a streaming scalar result object. @@ -524,7 +551,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): result = await self.stream(statement, parameters, execution_options) return result.scalars() - async def run_sync(self, fn, *arg, **kw): + async def run_sync( + self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any + ) -> Any: """Invoke the given sync callable passing self as the first argument. This method maintains the asyncio event loop all the way through @@ -548,14 +577,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :ref:`session_run_sync` """ - conn = self._sync_connection() - - return await greenlet_spawn(fn, conn, *arg, **kw) + return await greenlet_spawn(fn, self._proxied, *arg, **kw) - def __await__(self): + def __await__(self) -> Generator[Any, None, AsyncConnection]: return self.start().__await__() - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() # START PROXY METHODS AsyncConnection @@ -661,7 +688,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): ], attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], ) -class AsyncEngine(ProxyComparable, AsyncConnectable): +class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): """An asyncio proxy for a :class:`_engine.Engine`. :class:`_asyncio.AsyncEngine` is acquired using the @@ -679,51 +706,60 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): # current transaction, info, etc. It should be possible to # create a new AsyncEngine that matches this one given only the # "sync" elements. - __slots__ = ("sync_engine", "_proxied") + __slots__ = "sync_engine" - _connection_cls = AsyncConnection + _connection_cls: Type[AsyncConnection] = AsyncConnection - _option_cls: type + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ class _trans_ctx(StartableContext): - def __init__(self, conn): + __slots__ = ("conn", "transaction") + + conn: AsyncConnection + transaction: AsyncTransaction + + def __init__(self, conn: AsyncConnection): self.conn = conn - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: await self.conn.start(is_ctxmanager=is_ctxmanager) self.transaction = self.conn.begin() await self.transaction.__aenter__() return self.conn - async def __aexit__(self, type_, value, traceback): + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> None: await self.transaction.__aexit__(type_, value, traceback) await self.conn.close() - def __init__(self, sync_engine): + def __init__(self, sync_engine: Engine): if not sync_engine.dialect.is_async: raise exc.InvalidRequestError( "The asyncio extension requires an async driver to be used. " f"The loaded {sync_engine.dialect.driver!r} is not async." ) - self.sync_engine = self._proxied = self._assign_proxied(sync_engine) - - sync_engine: Engine - """Reference to the sync-style :class:`_engine.Engine` this - :class:`_asyncio.AsyncEngine` proxies requests towards. + self.sync_engine = self._assign_proxied(sync_engine) - This instance can be used as an event target. - - .. seealso:: - - :ref:`asyncio_events` - """ + @util.ro_non_memoized_property + def _proxied(self) -> Engine: + return self.sync_engine @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) - def begin(self): + def begin(self) -> AsyncEngine._trans_ctx: """Return a context manager which when entered will deliver an :class:`_asyncio.AsyncConnection` with an :class:`_asyncio.AsyncTransaction` established. @@ -741,7 +777,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): conn = self.connect() return self._trans_ctx(conn) - def connect(self): + def connect(self) -> AsyncConnection: """Return an :class:`_asyncio.AsyncConnection` object. The :class:`_asyncio.AsyncConnection` will procure a database @@ -759,7 +795,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): return self._connection_cls(self) - async def raw_connection(self): + async def raw_connection(self) -> PoolProxiedConnection: """Return a "raw" DBAPI connection from the connection pool. .. seealso:: @@ -769,7 +805,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): """ return await greenlet_spawn(self.sync_engine.raw_connection) - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> AsyncEngine: """Return a new :class:`_asyncio.AsyncEngine` that will provide :class:`_asyncio.AsyncConnection` objects with the given execution options. @@ -781,21 +817,31 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): return AsyncEngine(self.sync_engine.execution_options(**opt)) - async def dispose(self): + async def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this :class:`_asyncio.AsyncEngine`. - This will close all connection pool connections that are - **currently checked in**. See the documentation for the underlying - :meth:`_future.Engine.dispose` method for further notes. + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. .. seealso:: - :meth:`_future.Engine.dispose` + :meth:`_engine.Engine.dispose` """ - return await greenlet_spawn(self.sync_engine.dispose) + return await greenlet_spawn(self.sync_engine.dispose, close=close) # START PROXY METHODS AsyncEngine @@ -973,18 +1019,24 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): # END PROXY METHODS AsyncEngine -class AsyncTransaction(ProxyComparable, StartableContext): +class AsyncTransaction(ProxyComparable[Transaction], StartableContext): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") - def __init__(self, connection, nested=False): - self.connection = connection # AsyncConnection - self.sync_transaction = None # sqlalchemy.engine.Transaction + sync_transaction: Optional[Transaction] + connection: AsyncConnection + nested: bool + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction = None self.nested = nested @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls, target: Transaction + ) -> AsyncTransaction: sync_connection = target.connection sync_transaction = target nested = isinstance(target, NestedTransaction) @@ -1000,25 +1052,22 @@ class AsyncTransaction(ProxyComparable, StartableContext): obj.nested = nested return obj - def _sync_transaction(self): + @util.ro_non_memoized_property + def _proxied(self) -> Transaction: if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction @property - def _proxied(self): - return self.sync_transaction + def is_valid(self) -> bool: + return self._proxied.is_valid @property - def is_valid(self): - return self._sync_transaction().is_valid + def is_active(self) -> bool: + return self._proxied.is_active - @property - def is_active(self): - return self._sync_transaction().is_active - - async def close(self): - """Close this :class:`.Transaction`. + async def close(self) -> None: + """Close this :class:`.AsyncTransaction`. If this transaction is the base transaction in a begin/commit nesting, the transaction will rollback(). Otherwise, the @@ -1028,18 +1077,18 @@ class AsyncTransaction(ProxyComparable, StartableContext): an enclosing transaction. """ - await greenlet_spawn(self._sync_transaction().close) + await greenlet_spawn(self._proxied.close) - async def rollback(self): - """Roll back this :class:`.Transaction`.""" - await greenlet_spawn(self._sync_transaction().rollback) + async def rollback(self) -> None: + """Roll back this :class:`.AsyncTransaction`.""" + await greenlet_spawn(self._proxied.rollback) - async def commit(self): - """Commit this :class:`.Transaction`.""" + async def commit(self) -> None: + """Commit this :class:`.AsyncTransaction`.""" - await greenlet_spawn(self._sync_transaction().commit) + await greenlet_spawn(self._proxied.commit) - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction: """Start this :class:`_asyncio.AsyncTransaction` object's context outside of using a Python ``with:`` block. @@ -1047,24 +1096,36 @@ class AsyncTransaction(ProxyComparable, StartableContext): self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.connection._sync_connection().begin_nested + self.connection._proxied.begin_nested if self.nested - else self.connection._sync_connection().begin + else self.connection._proxied.begin ) ) if is_ctxmanager: self.sync_transaction.__enter__() return self - async def __aexit__(self, type_, value, traceback): - await greenlet_spawn( - self._sync_transaction().__exit__, type_, value, traceback - ) + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn(self._proxied.__exit__, type_, value, traceback) + + +@overload +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: + ... + + +@overload +def _get_sync_engine_or_connection( + async_engine: AsyncConnection, +) -> Connection: + ... -def _get_sync_engine_or_connection(async_engine): +def _get_sync_engine_or_connection( + async_engine: Union[AsyncEngine, AsyncConnection] +) -> Union[Engine, Connection]: if isinstance(async_engine, AsyncConnection): - return async_engine.sync_connection + return async_engine._proxied try: return async_engine.sync_engine @@ -1075,7 +1136,7 @@ def _get_sync_engine_or_connection(async_engine): @inspection._inspects(AsyncConnection) -def _no_insp_for_async_conn_yet(subject): +def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn: raise exc.NoInspectionAvailable( "Inspection on an AsyncConnection is currently not supported. " "Please use ``run_sync`` to pass a callable where it's possible " @@ -1085,7 +1146,7 @@ def _no_insp_for_async_conn_yet(subject): @inspection._inspects(AsyncEngine) -def _no_insp_for_async_engine_xyet(subject): +def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn: raise exc.NoInspectionAvailable( "Inspection on an AsyncEngine is currently not supported. " "Please obtain a connection then use ``conn.run_sync`` to pass a " diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py deleted file mode 100644 index c5d5e0126..000000000 --- a/lib/sqlalchemy/ext/asyncio/events.py +++ /dev/null @@ -1,44 +0,0 @@ -# ext/asyncio/events.py -# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors -# <see AUTHORS file> -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php - -from .engine import AsyncConnectable -from .session import AsyncSession -from ...engine import events as engine_event -from ...orm import events as orm_event - - -class AsyncConnectionEvents(engine_event.ConnectionEvents): - _target_class_doc = "SomeEngine" - _dispatch_target = AsyncConnectable - - @classmethod - def _no_async_engine_events(cls): - raise NotImplementedError( - "asynchronous events are not implemented at this time. Apply " - "synchronous listeners to the AsyncEngine.sync_engine or " - "AsyncConnection.sync_connection attributes." - ) - - @classmethod - def _listen(cls, event_key, retval=False): - cls._no_async_engine_events() - - -class AsyncSessionEvents(orm_event.SessionEvents): - _target_class_doc = "SomeSession" - _dispatch_target = AsyncSession - - @classmethod - def _no_async_engine_events(cls): - raise NotImplementedError( - "asynchronous events are not implemented at this time. Apply " - "synchronous listeners to the AsyncSession.sync_session." - ) - - @classmethod - def _listen(cls, event_key, retval=False): - cls._no_async_engine_events() diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 39718735c..a9db822a6 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -4,25 +4,49 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations import operator +from typing import Any +from typing import AsyncIterator +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar from . import exc as async_exc from ...engine.result import _NO_ROW +from ...engine.result import _R from ...engine.result import FilterResult from ...engine.result import FrozenResult from ...engine.result import MergedResult +from ...engine.result import ResultMetaData +from ...engine.row import Row +from ...engine.row import RowMapping from ...util.concurrency import greenlet_spawn +if TYPE_CHECKING: + from ...engine import CursorResult + from ...engine import Result + from ...engine.result import _KeyIndexType + from ...engine.result import _UniqueFilterType + from ...engine.result import RMKeyView -class AsyncCommon(FilterResult): - async def close(self): + +class AsyncCommon(FilterResult[_R]): + _real_result: Result + _metadata: ResultMetaData + + async def close(self) -> None: """Close this result.""" await greenlet_spawn(self._real_result.close) -class AsyncResult(AsyncCommon): +SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult") + + +class AsyncResult(AsyncCommon[Row]): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -43,7 +67,7 @@ class AsyncResult(AsyncCommon): """ - def __init__(self, real_result): + def __init__(self, real_result: Result): self._real_result = real_result self._metadata = real_result._metadata @@ -56,14 +80,16 @@ class AsyncResult(AsyncCommon): "_row_getter", real_result.__dict__["_row_getter"] ) - def keys(self): + def keys(self) -> RMKeyView: """Return the :meth:`_engine.Result.keys` collection from the underlying :class:`_engine.Result`. """ return self._metadata.keys - def unique(self, strategy=None): + def unique( + self: SelfAsyncResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfAsyncResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncResult`. @@ -75,7 +101,9 @@ class AsyncResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfAsyncResult, *col_expressions: _KeyIndexType + ) -> SelfAsyncResult: r"""Establish the columns that should be returned in each row. Refer to :meth:`_engine.Result.columns` in the synchronous @@ -85,7 +113,9 @@ class AsyncResult(AsyncCommon): """ return self._column_slices(col_expressions) - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[Row]]: """Iterate through sub-lists of rows of the size given. An async iterator is returned:: @@ -111,7 +141,7 @@ class AsyncResult(AsyncCommon): else: break - async def fetchone(self): + async def fetchone(self) -> Optional[Row]: """Fetch one row. When all rows are exhausted, returns None. @@ -131,9 +161,9 @@ class AsyncResult(AsyncCommon): if row is _NO_ROW: return None else: - return row # type: ignore[return-value] + return row - async def fetchmany(self, size=None): + async def fetchmany(self, size: Optional[int] = None) -> List[Row]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -152,11 +182,9 @@ class AsyncResult(AsyncCommon): """ - return await greenlet_spawn( - self._manyrow_getter, self, size # type: ignore - ) + return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): + async def all(self) -> List[Row]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -166,19 +194,19 @@ class AsyncResult(AsyncCommon): """ - return await greenlet_spawn(self._allrows) # type: ignore + return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncResult: return self - async def __anext__(self): + async def __anext__(self) -> Row: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[Row]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -201,7 +229,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[Row]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -223,7 +251,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def scalar_one(self): + async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and @@ -238,7 +266,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, True, True) - async def scalar_one_or_none(self): + async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and @@ -253,7 +281,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, True) - async def one(self): + async def one(self) -> Row: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -284,7 +312,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, True, False) - async def scalar(self): + async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. Returns None if there are no rows to fetch. @@ -300,7 +328,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, True) - async def freeze(self): + async def freeze(self) -> FrozenResult: """Return a callable object that will produce copies of this :class:`_asyncio.AsyncResult` when invoked. @@ -323,7 +351,7 @@ class AsyncResult(AsyncCommon): return await greenlet_spawn(FrozenResult, self) - def merge(self, *others): + def merge(self, *others: AsyncResult) -> MergedResult: """Merge this :class:`_asyncio.AsyncResult` with other compatible result objects. @@ -337,9 +365,12 @@ class AsyncResult(AsyncCommon): undefined. """ - return MergedResult(self._metadata, (self,) + others) + return MergedResult( + self._metadata, + (self._real_result,) + tuple(o._real_result for o in others), + ) - def scalars(self, index=0): + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -355,7 +386,7 @@ class AsyncResult(AsyncCommon): """ return AsyncScalarResult(self._real_result, index) - def mappings(self): + def mappings(self) -> AsyncMappingResult: """Apply a mappings filter to returned rows, returning an instance of :class:`_asyncio.AsyncMappingResult`. @@ -373,7 +404,12 @@ class AsyncResult(AsyncCommon): return AsyncMappingResult(self._real_result) -class AsyncScalarResult(AsyncCommon): +SelfAsyncScalarResult = TypeVar( + "SelfAsyncScalarResult", bound="AsyncScalarResult[Any]" +) + + +class AsyncScalarResult(AsyncCommon[_R]): """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values rather than :class:`_row.Row` values. @@ -389,7 +425,7 @@ class AsyncScalarResult(AsyncCommon): _generate_rows = False - def __init__(self, real_result, index): + def __init__(self, real_result: Result, index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -401,7 +437,10 @@ class AsyncScalarResult(AsyncCommon): self._unique_filter_state = real_result._unique_filter_state - def unique(self, strategy=None): + def unique( + self: SelfAsyncScalarResult, + strategy: Optional[_UniqueFilterType] = None, + ) -> SelfAsyncScalarResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncScalarResult`. @@ -411,7 +450,9 @@ class AsyncScalarResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -429,12 +470,12 @@ class AsyncScalarResult(AsyncCommon): else: break - async def fetchall(self): + async def fetchall(self) -> List[_R]: """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchmany(self, size=None): + async def fetchmany(self, size: Optional[int] = None) -> List[_R]: """Fetch many objects. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -444,7 +485,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): + async def all(self) -> List[_R]: """Return all scalar values in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -454,17 +495,17 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncScalarResult[_R]: return self - async def __anext__(self): + async def __anext__(self) -> _R: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[_R]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_asyncio.AsyncResult.first` except that @@ -474,7 +515,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[_R]: """Return at most one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that @@ -484,7 +525,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def one(self): + async def one(self) -> _R: """Return exactly one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one` except that @@ -495,7 +536,12 @@ class AsyncScalarResult(AsyncCommon): return await greenlet_spawn(self._only_one_row, True, True, False) -class AsyncMappingResult(AsyncCommon): +SelfAsyncMappingResult = TypeVar( + "SelfAsyncMappingResult", bound="AsyncMappingResult" +) + + +class AsyncMappingResult(AsyncCommon[RowMapping]): """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values rather than :class:`_engine.Row` values. @@ -513,14 +559,14 @@ class AsyncMappingResult(AsyncCommon): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result): + def __init__(self, result: Result): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata if result._source_supports_scalars: self._metadata = self._metadata._reduce([0]) - def keys(self): + def keys(self) -> RMKeyView: """Return an iterable view which yields the string keys that would be represented by each :class:`.Row`. @@ -535,7 +581,10 @@ class AsyncMappingResult(AsyncCommon): """ return self._metadata.keys - def unique(self, strategy=None): + def unique( + self: SelfAsyncMappingResult, + strategy: Optional[_UniqueFilterType] = None, + ) -> SelfAsyncMappingResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncMappingResult`. @@ -545,11 +594,16 @@ class AsyncMappingResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfAsyncMappingResult, *col_expressions: _KeyIndexType + ) -> SelfAsyncMappingResult: r"""Establish the columns that should be returned in each row.""" return self._column_slices(col_expressions) - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[RowMapping]]: + """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -567,12 +621,12 @@ class AsyncMappingResult(AsyncCommon): else: break - async def fetchall(self): + async def fetchall(self) -> List[RowMapping]: """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchone(self): + async def fetchone(self) -> Optional[RowMapping]: """Fetch one object. Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that @@ -587,8 +641,8 @@ class AsyncMappingResult(AsyncCommon): else: return row - async def fetchmany(self, size=None): - """Fetch many objects. + async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + """Fetch many rows. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that :class:`_result.RowMapping` values, rather than :class:`_result.Row` @@ -598,8 +652,8 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): - """Return all scalar values in a list. + async def all(self) -> List[RowMapping]: + """Return all rows in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that :class:`_result.RowMapping` values, rather than :class:`_result.Row` @@ -609,17 +663,17 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncMappingResult: return self - async def __anext__(self): + async def __anext__(self) -> RowMapping: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[RowMapping]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_asyncio.AsyncResult.first` except that @@ -630,7 +684,7 @@ class AsyncMappingResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[RowMapping]: """Return at most one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that @@ -640,7 +694,7 @@ class AsyncMappingResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def one(self): + async def one(self) -> RowMapping: """Return exactly one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one` except that @@ -651,11 +705,15 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._only_one_row, True, True, False) -async def _ensure_sync_result(result, calling_method): +_RT = TypeVar("_RT", bound="Result") + + +async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: + cursor_result: CursorResult if not result._is_cursor: - cursor_result = getattr(result, "raw", None) + cursor_result = getattr(result, "raw", None) # type: ignore else: - cursor_result = result + cursor_result = result # type: ignore if cursor_result and cursor_result.context._is_server_side: await greenlet_spawn(cursor_result.close) raise async_exc.AsyncMethodRequired( diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 0503076aa..0d6ae92b4 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -8,12 +8,50 @@ from __future__ import annotations from typing import Any - +from typing import Callable +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from .session import async_sessionmaker from .session import AsyncSession +from ... import exc as sa_exc from ... import util -from ...orm.scoping import ScopedSessionMixin +from ...orm.session import Session from ...util import create_proxy_methods from ...util import ScopedRegistry +from ...util import warn +from ...util import warn_deprecated + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .result import AsyncResult + from .result import AsyncScalarResult + from .session import AsyncSessionTransaction + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...engine.result import ScalarResult + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg @create_proxy_methods( @@ -62,7 +100,7 @@ from ...util import ScopedRegistry "info", ], ) -class async_scoped_session(ScopedSessionMixin): +class async_scoped_session: """Provides scoped management of :class:`.AsyncSession` objects. See the section :ref:`asyncio_scoped_session` for usage details. @@ -74,17 +112,23 @@ class async_scoped_session(ScopedSessionMixin): _support_async = True - def __init__(self, session_factory, scopefunc): + session_factory: async_sessionmaker + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.AsyncSession` is needed.""" + + registry: ScopedRegistry[AsyncSession] + + def __init__( + self, + session_factory: async_sessionmaker, + scopefunc: Callable[[], Any], + ): """Construct a new :class:`_asyncio.async_scoped_session`. :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` instances. This is usually, but not necessarily, an instance - of :class:`_orm.sessionmaker` which itself was passed the - :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_` - parameter:: - - async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession) - AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task) + of :class:`_asyncio.async_sessionmaker`. :param scopefunc: function which defines the current scope. A function such as ``asyncio.current_task`` @@ -96,10 +140,59 @@ class async_scoped_session(ScopedSessionMixin): self.registry = ScopedRegistry(session_factory, scopefunc) @property - def _proxied(self): + def _proxied(self) -> AsyncSession: return self.registry() - async def remove(self): + def __call__(self, **kw: Any) -> AsyncSession: + r"""Return the current :class:`.AsyncSession`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.AsyncSession` is not present. If the + :class:`.AsyncSession` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + async def remove(self) -> None: """Dispose of the current :class:`.AsyncSession`, if present. Different from scoped_session's remove method, this method would use @@ -152,7 +245,9 @@ class async_scoped_session(ScopedSessionMixin): Proxied for the :class:`_orm.Session` class on behalf of the :class:`_asyncio.AsyncSession` class. - """ + + + """ # noqa: E501 return self._proxied.__iter__() @@ -199,7 +294,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.add_all(instances) - def begin(self): + def begin(self) -> AsyncSessionTransaction: r"""Return an :class:`_asyncio.AsyncSessionTransaction` object. .. container:: class_bases @@ -228,7 +323,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.begin() - def begin_nested(self): + def begin_nested(self) -> AsyncSessionTransaction: r"""Return an :class:`_asyncio.AsyncSessionTransaction` object which will begin a "nested" transaction, e.g. SAVEPOINT. @@ -247,7 +342,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.begin_nested() - async def close(self): + async def close(self) -> None: r"""Close out the transactional resources and ORM objects used by this :class:`_asyncio.AsyncSession`. @@ -284,7 +379,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.close() - async def commit(self): + async def commit(self) -> None: r"""Commit the current transaction in progress. .. container:: class_bases @@ -296,7 +391,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.commit() - async def connection(self, **kw): + async def connection(self, **kw: Any) -> AsyncConnection: r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this :class:`.Session` object's transactional state. @@ -321,7 +416,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.connection(**kw) - async def delete(self, instance): + async def delete(self, instance: object) -> None: r"""Mark an instance as deleted. .. container:: class_bases @@ -345,12 +440,12 @@ class async_scoped_session(ScopedSessionMixin): async def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result: r"""Execute a statement and return a buffered :class:`_engine.Result` object. @@ -519,7 +614,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.expunge_all() - async def flush(self, objects=None): + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: r"""Flush all the object changes to the database. .. container:: class_bases @@ -538,13 +633,15 @@ class async_scoped_session(ScopedSessionMixin): async def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: r"""Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -568,9 +665,16 @@ class async_scoped_session(ScopedSessionMixin): populate_existing=populate_existing, with_for_update=with_for_update, identity_token=identity_token, + execution_options=execution_options, ) - def get_bind(self, mapper=None, clause=None, bind=None, **kw): + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session` is bound. @@ -724,7 +828,7 @@ class async_scoped_session(ScopedSessionMixin): instance, include_collections=include_collections ) - async def invalidate(self): + async def invalidate(self) -> None: r"""Close this Session, using connection invalidation. .. container:: class_bases @@ -738,7 +842,13 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.invalidate() - async def merge(self, instance, load=True, options=None): + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: r"""Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. @@ -757,8 +867,11 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.merge(instance, load=load, options=options) async def refresh( - self, instance, attribute_names=None, with_for_update=None - ): + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: r"""Expire and refresh the attributes on the given instance. .. container:: class_bases @@ -785,7 +898,7 @@ class async_scoped_session(ScopedSessionMixin): with_for_update=with_for_update, ) - async def rollback(self): + async def rollback(self) -> None: r"""Rollback the current transaction in progress. .. container:: class_bases @@ -799,12 +912,12 @@ class async_scoped_session(ScopedSessionMixin): async def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: r"""Execute a statement and return a scalar result. .. container:: class_bases @@ -829,12 +942,12 @@ class async_scoped_session(ScopedSessionMixin): async def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: r"""Execute a statement and return scalar results. .. container:: class_bases @@ -865,12 +978,12 @@ class async_scoped_session(ScopedSessionMixin): async def stream( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult: r"""Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -892,12 +1005,12 @@ class async_scoped_session(ScopedSessionMixin): async def stream_scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: r"""Execute a statement and return a stream of scalar results. .. container:: class_bases @@ -1159,7 +1272,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.info @classmethod - async def close_all(self): + async def close_all(self) -> None: r"""Close all :class:`_asyncio.AsyncSession` sessions. .. container:: class_bases diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 769fe05bd..7d63b084c 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -7,17 +7,65 @@ from __future__ import annotations from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import engine -from . import result as _result from .base import ReversibleProxy from .base import StartableContext from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult from ... import util from ...orm import object_session from ...orm import Session +from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Protocol + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import ScalarResult + from ...engine import Transaction + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + + +class _SyncSessionCallable(Protocol): + def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: + ... + _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) @@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) "info", ], ) -class AsyncSession(ReversibleProxy): +class AsyncSession(ReversibleProxy[Session]): """Asyncio version of :class:`_orm.Session`. The :class:`_asyncio.AsyncSession` is a proxy for a traditional @@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy): _is_asyncio = True - dispatch = None + dispatch: dispatcher[Session] - def __init__(self, bind=None, binds=None, sync_session_class=None, **kw): + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): r"""Construct a new :class:`_asyncio.AsyncSession`. All parameters other than ``sync_session_class`` are passed to the @@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy): .. versionadded:: 1.4.24 """ - kw["future"] = True + sync_bind = sync_binds = None + if bind: self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) + sync_bind = engine._get_sync_engine_or_connection(bind) if binds: self.binds = binds - binds = { + sync_binds = { key: engine._get_sync_engine_or_connection(b) for key, b in binds.items() } @@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy): self.sync_session_class = sync_session_class self.sync_session = self._proxied = self._assign_proxied( - self.sync_session_class(bind=bind, binds=binds, **kw) + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) ) - sync_session_class = Session + sync_session_class: Type[Session] = Session """The class or callable that provides the underlying :class:`_orm.Session` instance for a particular :class:`_asyncio.AsyncSession`. @@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy): """ + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + async def refresh( - self, instance, attribute_names=None, with_for_update=None - ): + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: """Expire and refresh the attributes on the given instance. A query will be issued to the database and all attributes will be @@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + await greenlet_spawn( self.sync_session.refresh, instance, attribute_names=attribute_names, with_for_update=with_for_update, ) - async def run_sync(self, fn, *arg, **kw): + async def run_sync( + self, fn: _SyncSessionCallable, *arg: Any, **kw: Any + ) -> Any: """Invoke the given sync callable passing sync self as the first argument. @@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy): async def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result: """Execute a statement and return a buffered :class:`_engine.Result` object. @@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy): async def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: """Execute a statement and return a scalar result. .. seealso:: @@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy): async def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: """Execute a statement and return scalar results. :return: a :class:`_result.ScalarResult` object @@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy): async def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + + result_obj = await greenlet_spawn( self.sync_session.get, entity, ident, @@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy): with_for_update=with_for_update, identity_token=identity_token, ) + return result_obj async def stream( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult: + """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy): bind_arguments=bind_arguments, **kw, ) - return _result.AsyncResult(result) + return AsyncResult(result) async def stream_scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: """Execute a statement and return a stream of scalar results. :return: an :class:`_asyncio.AsyncScalarResult` object @@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy): ) return result.scalars() - async def delete(self, instance): + async def delete(self, instance: object) -> None: """Mark an instance as deleted. The database delete operation occurs upon ``flush()``. @@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy): :meth:`_orm.Session.delete` - main documentation for delete """ - return await greenlet_spawn(self.sync_session.delete, instance) + await greenlet_spawn(self.sync_session.delete, instance) - async def merge(self, instance, load=True, options=None): + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: """Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. @@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy): self.sync_session.merge, instance, load=load, options=options ) - async def flush(self, objects=None): + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. .. seealso:: @@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy): """ await greenlet_spawn(self.sync_session.flush, objects=objects) - def get_transaction(self): + def get_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current root transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy): else: return None - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current nested transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy): else: return None - def get_bind(self, mapper=None, clause=None, bind=None, **kw): + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: """Return a "bind" to which the synchronous proxied :class:`_orm.Session` is bound. @@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy): mapper=mapper, clause=clause, bind=bind, **kw ) - async def connection(self, **kw): + async def connection(self, **kw: Any) -> AsyncConnection: r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this :class:`.Session` object's transactional state. @@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy): sync_connection ) - def begin(self): + def begin(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object. The underlying :class:`_orm.Session` will perform the @@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self) - def begin_nested(self): + def begin_nested(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object which will begin a "nested" transaction, e.g. SAVEPOINT. @@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self, nested=True) - async def rollback(self): + async def rollback(self) -> None: """Rollback the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.rollback) + await greenlet_spawn(self.sync_session.rollback) - async def commit(self): + async def commit(self) -> None: """Commit the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.commit) + await greenlet_spawn(self.sync_session.commit) - async def close(self): + async def close(self) -> None: """Close out the transactional resources and ORM objects used by this :class:`_asyncio.AsyncSession`. @@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy): """ return await greenlet_spawn(self.sync_session.close) - async def invalidate(self): + async def invalidate(self) -> None: """Close this Session, using connection invalidation. For a complete description, see :meth:`_orm.Session.invalidate`. """ - return await greenlet_spawn(self.sync_session.invalidate) + await greenlet_spawn(self.sync_session.invalidate) @classmethod - async def close_all(self): + async def close_all(self) -> None: """Close all :class:`_asyncio.AsyncSession` sessions.""" - return await greenlet_spawn(self.sync_session.close_all) + await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self): + def _maker_context_manager(self) -> _AsyncSessionContextManager: # TODO: can this use asynccontextmanager ?? return _AsyncSessionContextManager(self) @@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy): # END PROXY METHODS AsyncSession +class async_sessionmaker: + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def main(): + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/') + + AsyncSession = async_sessionmaker(engine) + + async with async_session() as session: + session.add(some_object) + session.add(some_other_object) + await session.commit() + + .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[AsyncSession] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + class_: Type[AsyncSession] = AsyncSession, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> AsyncSession: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + class _AsyncSessionContextManager: - def __init__(self, async_session): + __slots__ = ("async_session", "trans") + + async_session: AsyncSession + trans: AsyncSessionTransaction + + def __init__(self, async_session: AsyncSession): self.async_session = async_session - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.trans.__aexit__(type_, value, traceback) await self.async_session.__aexit__(type_, value, traceback) -class AsyncSessionTransaction(ReversibleProxy, StartableContext): +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], StartableContext +): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. This object is provided so that a transaction-holding object @@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): __slots__ = ("session", "sync_transaction", "nested") - def __init__(self, session, nested=False): + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): self.session = session self.nested = nested self.sync_transaction = None @property - def is_active(self): + def is_active(self) -> bool: return ( self._sync_transaction() is not None and self._sync_transaction().is_active ) - def _sync_transaction(self): + def _sync_transaction(self) -> SessionTransaction: if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction - async def rollback(self): + async def rollback(self) -> None: """Roll back this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().rollback) - async def commit(self): + async def commit(self) -> None: """Commit this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().commit) - async def start(self, is_ctxmanager=False): + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.session.sync_session.begin_nested + self.session.sync_session.begin_nested # type: ignore if self.nested else self.session.sync_session.begin ) @@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): self.sync_transaction.__enter__() return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await greenlet_spawn( self._sync_transaction().__exit__, type_, value, traceback ) -def async_object_session(instance): +def async_object_session(instance: object) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` to which the given instance belongs. @@ -1247,7 +1475,7 @@ def async_object_session(instance): return None -def async_session(session: Session) -> AsyncSession: +def async_session(session: Session) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` which is proxying the given :class:`_orm.Session` object, if any. @@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession: return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) -_instance_state._async_provider = async_session +_instance_state._async_provider = async_session # type: ignore |