diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-10 15:42:35 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-11 22:11:07 -0400 |
commit | a45e2284dad17fbbba3bea9d5e5304aab21c8c94 (patch) | |
tree | ac31614f2d53059570e2edffe731baf384baea23 /lib/sqlalchemy/ext/asyncio/session.py | |
parent | aa9cd878e8249a4a758c7f968e929e92fede42a5 (diff) | |
download | sqlalchemy-a45e2284dad17fbbba3bea9d5e5304aab21c8c94.tar.gz |
pep-484: asyncio
in this patch the asyncio/events.py module, which
existed only to raise errors when trying to attach event
listeners, is removed, as we were already coding an asyncio-specific
workaround in upstream Pool / Session to raise this error,
just moved the error out to the target and did the same thing
for Engine.
We also add an async_sessionmaker class. The initial rationale
here is because sessionmaker() is hardcoded to Session subclasses,
and there's not a way to get the use case of
sessionmaker(class_=AsyncSession) to type correctly without changing
the sessionmaker() symbol itself to be a function and not a class,
which gets too complicated for what this is. Additionally,
_SessionClassMethods has only three methods on it, one of which
is not usable with asyncio (close_all()), the others
not generally used from the session class.
Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/session.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 406 |
1 files changed, 317 insertions, 89 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 769fe05bd..7d63b084c 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -7,17 +7,65 @@ from __future__ import annotations from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import engine -from . import result as _result from .base import ReversibleProxy from .base import StartableContext from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult from ... import util from ...orm import object_session from ...orm import Session +from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Protocol + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import ScalarResult + from ...engine import Transaction + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + + +class _SyncSessionCallable(Protocol): + def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: + ... + _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) @@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) "info", ], ) -class AsyncSession(ReversibleProxy): +class AsyncSession(ReversibleProxy[Session]): """Asyncio version of :class:`_orm.Session`. The :class:`_asyncio.AsyncSession` is a proxy for a traditional @@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy): _is_asyncio = True - dispatch = None + dispatch: dispatcher[Session] - def __init__(self, bind=None, binds=None, sync_session_class=None, **kw): + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): r"""Construct a new :class:`_asyncio.AsyncSession`. All parameters other than ``sync_session_class`` are passed to the @@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy): .. versionadded:: 1.4.24 """ - kw["future"] = True + sync_bind = sync_binds = None + if bind: self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) + sync_bind = engine._get_sync_engine_or_connection(bind) if binds: self.binds = binds - binds = { + sync_binds = { key: engine._get_sync_engine_or_connection(b) for key, b in binds.items() } @@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy): self.sync_session_class = sync_session_class self.sync_session = self._proxied = self._assign_proxied( - self.sync_session_class(bind=bind, binds=binds, **kw) + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) ) - sync_session_class = Session + sync_session_class: Type[Session] = Session """The class or callable that provides the underlying :class:`_orm.Session` instance for a particular :class:`_asyncio.AsyncSession`. @@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy): """ + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + async def refresh( - self, instance, attribute_names=None, with_for_update=None - ): + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: """Expire and refresh the attributes on the given instance. A query will be issued to the database and all attributes will be @@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + await greenlet_spawn( self.sync_session.refresh, instance, attribute_names=attribute_names, with_for_update=with_for_update, ) - async def run_sync(self, fn, *arg, **kw): + async def run_sync( + self, fn: _SyncSessionCallable, *arg: Any, **kw: Any + ) -> Any: """Invoke the given sync callable passing sync self as the first argument. @@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy): async def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result: """Execute a statement and return a buffered :class:`_engine.Result` object. @@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy): async def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: """Execute a statement and return a scalar result. .. seealso:: @@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy): async def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: """Execute a statement and return scalar results. :return: a :class:`_result.ScalarResult` object @@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy): async def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + + result_obj = await greenlet_spawn( self.sync_session.get, entity, ident, @@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy): with_for_update=with_for_update, identity_token=identity_token, ) + return result_obj async def stream( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult: + """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy): bind_arguments=bind_arguments, **kw, ) - return _result.AsyncResult(result) + return AsyncResult(result) async def stream_scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: """Execute a statement and return a stream of scalar results. :return: an :class:`_asyncio.AsyncScalarResult` object @@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy): ) return result.scalars() - async def delete(self, instance): + async def delete(self, instance: object) -> None: """Mark an instance as deleted. The database delete operation occurs upon ``flush()``. @@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy): :meth:`_orm.Session.delete` - main documentation for delete """ - return await greenlet_spawn(self.sync_session.delete, instance) + await greenlet_spawn(self.sync_session.delete, instance) - async def merge(self, instance, load=True, options=None): + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: """Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. @@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy): self.sync_session.merge, instance, load=load, options=options ) - async def flush(self, objects=None): + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. .. seealso:: @@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy): """ await greenlet_spawn(self.sync_session.flush, objects=objects) - def get_transaction(self): + def get_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current root transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy): else: return None - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current nested transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy): else: return None - def get_bind(self, mapper=None, clause=None, bind=None, **kw): + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: """Return a "bind" to which the synchronous proxied :class:`_orm.Session` is bound. @@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy): mapper=mapper, clause=clause, bind=bind, **kw ) - async def connection(self, **kw): + async def connection(self, **kw: Any) -> AsyncConnection: r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this :class:`.Session` object's transactional state. @@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy): sync_connection ) - def begin(self): + def begin(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object. The underlying :class:`_orm.Session` will perform the @@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self) - def begin_nested(self): + def begin_nested(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object which will begin a "nested" transaction, e.g. SAVEPOINT. @@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self, nested=True) - async def rollback(self): + async def rollback(self) -> None: """Rollback the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.rollback) + await greenlet_spawn(self.sync_session.rollback) - async def commit(self): + async def commit(self) -> None: """Commit the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.commit) + await greenlet_spawn(self.sync_session.commit) - async def close(self): + async def close(self) -> None: """Close out the transactional resources and ORM objects used by this :class:`_asyncio.AsyncSession`. @@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy): """ return await greenlet_spawn(self.sync_session.close) - async def invalidate(self): + async def invalidate(self) -> None: """Close this Session, using connection invalidation. For a complete description, see :meth:`_orm.Session.invalidate`. """ - return await greenlet_spawn(self.sync_session.invalidate) + await greenlet_spawn(self.sync_session.invalidate) @classmethod - async def close_all(self): + async def close_all(self) -> None: """Close all :class:`_asyncio.AsyncSession` sessions.""" - return await greenlet_spawn(self.sync_session.close_all) + await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self): + def _maker_context_manager(self) -> _AsyncSessionContextManager: # TODO: can this use asynccontextmanager ?? return _AsyncSessionContextManager(self) @@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy): # END PROXY METHODS AsyncSession +class async_sessionmaker: + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def main(): + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/') + + AsyncSession = async_sessionmaker(engine) + + async with async_session() as session: + session.add(some_object) + session.add(some_other_object) + await session.commit() + + .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[AsyncSession] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + class_: Type[AsyncSession] = AsyncSession, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> AsyncSession: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + class _AsyncSessionContextManager: - def __init__(self, async_session): + __slots__ = ("async_session", "trans") + + async_session: AsyncSession + trans: AsyncSessionTransaction + + def __init__(self, async_session: AsyncSession): self.async_session = async_session - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.trans.__aexit__(type_, value, traceback) await self.async_session.__aexit__(type_, value, traceback) -class AsyncSessionTransaction(ReversibleProxy, StartableContext): +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], StartableContext +): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. This object is provided so that a transaction-holding object @@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): __slots__ = ("session", "sync_transaction", "nested") - def __init__(self, session, nested=False): + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): self.session = session self.nested = nested self.sync_transaction = None @property - def is_active(self): + def is_active(self) -> bool: return ( self._sync_transaction() is not None and self._sync_transaction().is_active ) - def _sync_transaction(self): + def _sync_transaction(self) -> SessionTransaction: if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction - async def rollback(self): + async def rollback(self) -> None: """Roll back this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().rollback) - async def commit(self): + async def commit(self) -> None: """Commit this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().commit) - async def start(self, is_ctxmanager=False): + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.session.sync_session.begin_nested + self.session.sync_session.begin_nested # type: ignore if self.nested else self.session.sync_session.begin ) @@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): self.sync_transaction.__enter__() return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await greenlet_spawn( self._sync_transaction().__exit__, type_, value, traceback ) -def async_object_session(instance): +def async_object_session(instance: object) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` to which the given instance belongs. @@ -1247,7 +1475,7 @@ def async_object_session(instance): return None -def async_session(session: Session) -> AsyncSession: +def async_session(session: Session) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` which is proxying the given :class:`_orm.Session` object, if any. @@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession: return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) -_instance_state._async_provider = async_session +_instance_state._async_provider = async_session # type: ignore |