summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/session.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-10 15:42:35 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-11 22:11:07 -0400
commita45e2284dad17fbbba3bea9d5e5304aab21c8c94 (patch)
treeac31614f2d53059570e2edffe731baf384baea23 /lib/sqlalchemy/ext/asyncio/session.py
parentaa9cd878e8249a4a758c7f968e929e92fede42a5 (diff)
downloadsqlalchemy-a45e2284dad17fbbba3bea9d5e5304aab21c8c94.tar.gz
pep-484: asyncio
in this patch the asyncio/events.py module, which existed only to raise errors when trying to attach event listeners, is removed, as we were already coding an asyncio-specific workaround in upstream Pool / Session to raise this error, just moved the error out to the target and did the same thing for Engine. We also add an async_sessionmaker class. The initial rationale here is because sessionmaker() is hardcoded to Session subclasses, and there's not a way to get the use case of sessionmaker(class_=AsyncSession) to type correctly without changing the sessionmaker() symbol itself to be a function and not a class, which gets too complicated for what this is. Additionally, _SessionClassMethods has only three methods on it, one of which is not usable with asyncio (close_all()), the others not generally used from the session class. Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/session.py')
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py406
1 files changed, 317 insertions, 89 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 769fe05bd..7d63b084c 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -7,17 +7,65 @@
from __future__ import annotations
from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from . import engine
-from . import result as _result
from .base import ReversibleProxy
from .base import StartableContext
from .result import _ensure_sync_result
+from .result import AsyncResult
+from .result import AsyncScalarResult
from ... import util
from ...orm import object_session
from ...orm import Session
+from ...orm import SessionTransaction
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+ from .engine import AsyncConnection
+ from .engine import AsyncEngine
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine import Result
+ from ...engine import Row
+ from ...engine import ScalarResult
+ from ...engine import Transaction
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...event import dispatcher
+ from ...orm._typing import _IdentityKeyType
+ from ...orm._typing import _O
+ from ...orm.identity import IdentityMap
+ from ...orm.interfaces import ORMOption
+ from ...orm.session import _BindArguments
+ from ...orm.session import _EntityBindKey
+ from ...orm.session import _PKIdentityArgument
+ from ...orm.session import _SessionBind
+ from ...orm.session import _SessionBindKey
+ from ...sql.base import Executable
+ from ...sql.elements import ClauseElement
+ from ...sql.selectable import ForUpdateArg
+
+_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+
+
+class _SyncSessionCallable(Protocol):
+ def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
+ ...
+
_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
@@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
"info",
],
)
-class AsyncSession(ReversibleProxy):
+class AsyncSession(ReversibleProxy[Session]):
"""Asyncio version of :class:`_orm.Session`.
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
@@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy):
_is_asyncio = True
- dispatch = None
+ dispatch: dispatcher[Session]
- def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None,
+ sync_session_class: Optional[Type[Session]] = None,
+ **kw: Any,
+ ):
r"""Construct a new :class:`_asyncio.AsyncSession`.
All parameters other than ``sync_session_class`` are passed to the
@@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy):
.. versionadded:: 1.4.24
"""
- kw["future"] = True
+ sync_bind = sync_binds = None
+
if bind:
self.bind = bind
- bind = engine._get_sync_engine_or_connection(bind)
+ sync_bind = engine._get_sync_engine_or_connection(bind)
if binds:
self.binds = binds
- binds = {
+ sync_binds = {
key: engine._get_sync_engine_or_connection(b)
for key, b in binds.items()
}
@@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy):
self.sync_session_class = sync_session_class
self.sync_session = self._proxied = self._assign_proxied(
- self.sync_session_class(bind=bind, binds=binds, **kw)
+ self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw)
)
- sync_session_class = Session
+ sync_session_class: Type[Session] = Session
"""The class or callable that provides the
underlying :class:`_orm.Session` instance for a particular
:class:`_asyncio.AsyncSession`.
@@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy):
"""
+ @classmethod
+ def _no_async_engine_events(cls) -> NoReturn:
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
async def refresh(
- self, instance, attribute_names=None, with_for_update=None
- ):
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
"""Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
@@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+ await greenlet_spawn(
self.sync_session.refresh,
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
- async def run_sync(self, fn, *arg, **kw):
+ async def run_sync(
+ self, fn: _SyncSessionCallable, *arg: Any, **kw: Any
+ ) -> Any:
"""Invoke the given sync callable passing sync self as the first
argument.
@@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy):
async def execute(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Result:
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy):
async def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
"""Execute a statement and return a scalar result.
.. seealso::
@@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy):
async def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
"""Execute a statement and return scalar results.
:return: a :class:`_result.ScalarResult` object
@@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy):
async def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
+
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+
+ result_obj = await greenlet_spawn(
self.sync_session.get,
entity,
ident,
@@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy):
with_for_update=with_for_update,
identity_token=identity_token,
)
+ return result_obj
async def stream(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult:
+
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy):
bind_arguments=bind_arguments,
**kw,
)
- return _result.AsyncResult(result)
+ return AsyncResult(result)
async def stream_scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
"""Execute a statement and return a stream of scalar results.
:return: an :class:`_asyncio.AsyncScalarResult` object
@@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy):
)
return result.scalars()
- async def delete(self, instance):
+ async def delete(self, instance: object) -> None:
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
@@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy):
:meth:`_orm.Session.delete` - main documentation for delete
"""
- return await greenlet_spawn(self.sync_session.delete, instance)
+ await greenlet_spawn(self.sync_session.delete, instance)
- async def merge(self, instance, load=True, options=None):
+ async def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
@@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy):
self.sync_session.merge, instance, load=load, options=options
)
- async def flush(self, objects=None):
+ async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
"""Flush all the object changes to the database.
.. seealso::
@@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy):
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
- def get_transaction(self):
+ def get_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current root transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current nested transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ def get_bind(
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ **kw: Any,
+ ) -> Union[Engine, Connection]:
"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
@@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy):
mapper=mapper, clause=clause, bind=bind, **kw
)
- async def connection(self, **kw):
+ async def connection(self, **kw: Any) -> AsyncConnection:
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
@@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy):
sync_connection
)
- def begin(self):
+ def begin(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
The underlying :class:`_orm.Session` will perform the
@@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self)
- def begin_nested(self):
+ def begin_nested(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
@@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self, nested=True)
- async def rollback(self):
+ async def rollback(self) -> None:
"""Rollback the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.rollback)
+ await greenlet_spawn(self.sync_session.rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.commit)
+ await greenlet_spawn(self.sync_session.commit)
- async def close(self):
+ async def close(self) -> None:
"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
@@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy):
"""
return await greenlet_spawn(self.sync_session.close)
- async def invalidate(self):
+ async def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
For a complete description, see :meth:`_orm.Session.invalidate`.
"""
- return await greenlet_spawn(self.sync_session.invalidate)
+ await greenlet_spawn(self.sync_session.invalidate)
@classmethod
- async def close_all(self):
+ async def close_all(self) -> None:
"""Close all :class:`_asyncio.AsyncSession` sessions."""
- return await greenlet_spawn(self.sync_session.close_all)
+ await greenlet_spawn(self.sync_session.close_all)
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
- def _maker_context_manager(self):
+ def _maker_context_manager(self) -> _AsyncSessionContextManager:
# TODO: can this use asynccontextmanager ??
return _AsyncSessionContextManager(self)
@@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy):
# END PROXY METHODS AsyncSession
+class async_sessionmaker:
+ """A configurable :class:`.AsyncSession` factory.
+
+ The :class:`.async_sessionmaker` factory works in the same way as the
+ :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession`
+ objects when called, creating them given
+ the configurational arguments established here.
+
+ e.g.::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.ext.asyncio import async_sessionmaker
+
+ async def main():
+ # an AsyncEngine, which the AsyncSession will use for connection
+ # resources
+ engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/')
+
+ AsyncSession = async_sessionmaker(engine)
+
+ async with async_session() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ await session.commit()
+
+ .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a
+ :class:`.sessionmaker` class that's dedicated to the
+ :class:`.AsyncSession` object, including pep-484 typing support.
+
+ .. seealso::
+
+ :ref:`asyncio_orm` - shows example use
+
+ :class:`.sessionmaker` - general overview of the
+ :class:`.sessionmaker` architecture
+
+
+ :ref:`session_getting` - introductory text on creating
+ sessions using :class:`.sessionmaker`.
+
+ """ # noqa E501
+
+ class_: Type[AsyncSession]
+
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ class_: Type[AsyncSession] = AsyncSession,
+ autoflush: bool = True,
+ expire_on_commit: bool = True,
+ info: Optional[Dict[Any, Any]] = None,
+ **kw: Any,
+ ):
+ r"""Construct a new :class:`.async_sessionmaker`.
+
+ All arguments here except for ``class_`` correspond to arguments
+ accepted by :class:`.Session` directly. See the
+ :meth:`.AsyncSession.__init__` docstring for more details on
+ parameters.
+
+
+ """
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["expire_on_commit"] = expire_on_commit
+ if info is not None:
+ kw["info"] = info
+ self.kw = kw
+ self.class_ = class_
+
+ def begin(self) -> _AsyncSessionContextManager:
+ """Produce a context manager that both provides a new
+ :class:`_orm.AsyncSession` as well as a transaction that commits.
+
+
+ e.g.::
+
+ async def main():
+ Session = async_sessionmaker(some_engine)
+
+ async with Session.begin() as session:
+ session.add(some_object)
+
+ # commits transaction, closes session
+
+
+ """
+
+ session = self()
+ return session._maker_context_manager()
+
+ def __call__(self, **local_kw: Any) -> AsyncSession:
+ """Produce a new :class:`.AsyncSession` object using the configuration
+ established in this :class:`.async_sessionmaker`.
+
+ In Python, the ``__call__`` method is invoked on an object when
+ it is "called" in the same way as a function::
+
+ AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False)
+ session = AsyncSession() # invokes sessionmaker.__call__()
+
+ """ # noqa E501
+ for k, v in self.kw.items():
+ if k == "info" and "info" in local_kw:
+ d = v.copy()
+ d.update(local_kw["info"])
+ local_kw["info"] = d
+ else:
+ local_kw.setdefault(k, v)
+ return self.class_(**local_kw)
+
+ def configure(self, **new_kw: Any) -> None:
+ """(Re)configure the arguments for this async_sessionmaker.
+
+ e.g.::
+
+ AsyncSession = async_sessionmaker(some_engine)
+
+ AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
+ """ # noqa E501
+
+ self.kw.update(new_kw)
+
+ def __repr__(self) -> str:
+ return "%s(class_=%r, %s)" % (
+ self.__class__.__name__,
+ self.class_.__name__,
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+ )
+
+
class _AsyncSessionContextManager:
- def __init__(self, async_session):
+ __slots__ = ("async_session", "trans")
+
+ async_session: AsyncSession
+ trans: AsyncSessionTransaction
+
+ def __init__(self, async_session: AsyncSession):
self.async_session = async_session
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
self.trans = self.async_session.begin()
await self.trans.__aenter__()
return self.async_session
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)
-class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+class AsyncSessionTransaction(
+ ReversibleProxy[SessionTransaction], StartableContext
+):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
@@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
__slots__ = ("session", "sync_transaction", "nested")
- def __init__(self, session, nested=False):
+ session: AsyncSession
+ sync_transaction: Optional[SessionTransaction]
+
+ def __init__(self, session: AsyncSession, nested: bool = False):
self.session = session
self.nested = nested
self.sync_transaction = None
@property
- def is_active(self):
+ def is_active(self) -> bool:
return (
self._sync_transaction() is not None
and self._sync_transaction().is_active
)
- def _sync_transaction(self):
+ def _sync_transaction(self) -> SessionTransaction:
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
- async def rollback(self):
+ async def rollback(self) -> None:
"""Roll back this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().commit)
- async def start(self, is_ctxmanager=False):
+ async def start(
+ self, is_ctxmanager: bool = False
+ ) -> AsyncSessionTransaction:
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
- self.session.sync_session.begin_nested
+ self.session.sync_session.begin_nested # type: ignore
if self.nested
else self.session.sync_session.begin
)
@@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
self.sync_transaction.__enter__()
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
-def async_object_session(instance):
+def async_object_session(instance: object) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` to which the given instance
belongs.
@@ -1247,7 +1475,7 @@ def async_object_session(instance):
return None
-def async_session(session: Session) -> AsyncSession:
+def async_session(session: Session) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` which is proxying the given
:class:`_orm.Session` object, if any.
@@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession:
return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
-_instance_state._async_provider = async_session
+_instance_state._async_provider = async_session # type: ignore