diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-26 14:35:03 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-31 15:17:48 -0400 |
commit | d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch) | |
tree | 4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /lib/sqlalchemy/ext/asyncio | |
parent | 5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff) | |
download | sqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz |
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes
can be typed for these factories. Added appropriate
typevars to `__call__()`, `__enter__()` and other methods
so that a custom Session or AsyncSession subclass is carried
through.
Fixes: #7656
Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/scoping.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 33 |
3 files changed, 39 insertions, 29 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 6d07d843c..97d69fcbd 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -1068,10 +1068,15 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): @property def engine(self) -> Any: - r""".. container:: class_bases + r"""Returns this :class:`.Engine`. - Proxied for the :class:`_engine.Engine` class - on behalf of the :class:`_asyncio.AsyncEngine` class. + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. """ # noqa: E501 diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 22a060a0d..8d31dd07d 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -9,6 +9,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import Generic from typing import Iterable from typing import Iterator from typing import Optional @@ -20,6 +21,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .session import _AS from .session import async_sessionmaker from .session import AsyncSession from ... import exc as sa_exc @@ -104,7 +106,7 @@ _T = TypeVar("_T", bound=Any) "info", ], ) -class async_scoped_session: +class async_scoped_session(Generic[_AS]): """Provides scoped management of :class:`.AsyncSession` objects. See the section :ref:`asyncio_scoped_session` for usage details. @@ -116,16 +118,16 @@ class async_scoped_session: _support_async = True - session_factory: async_sessionmaker + session_factory: async_sessionmaker[_AS] """The `session_factory` provided to `__init__` is stored in this attribute and may be accessed at a later time. This can be useful when a new non-scoped :class:`.AsyncSession` is needed.""" - registry: ScopedRegistry[AsyncSession] + registry: ScopedRegistry[_AS] def __init__( self, - session_factory: async_sessionmaker, + session_factory: async_sessionmaker[_AS], scopefunc: Callable[[], Any], ): """Construct a new :class:`_asyncio.async_scoped_session`. @@ -144,10 +146,10 @@ class async_scoped_session: self.registry = ScopedRegistry(session_factory, scopefunc) @property - def _proxied(self) -> AsyncSession: + def _proxied(self) -> _AS: return self.registry() - def __call__(self, **kw: Any) -> AsyncSession: + def __call__(self, **kw: Any) -> _AS: r"""Return the current :class:`.AsyncSession`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -450,8 +452,8 @@ class async_scoped_session: This method may also be used to establish execution options for the database connection used by the current transaction. - .. versionadded:: 1.4.24 Added **kw arguments which are passed through - to the underlying :meth:`_orm.Session.connection` method. + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. .. seealso:: @@ -752,7 +754,8 @@ class async_scoped_session: from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -775,8 +778,7 @@ class async_scoped_session: ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index eac2e5806..be3414cef 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import NoReturn @@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]): ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) @@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]): """Close all :class:`_asyncio.AsyncSession` sessions.""" await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self: _AS) -> _AS: return self async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self) -> _AsyncSessionContextManager: - # TODO: can this use asynccontextmanager ?? + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: return _AsyncSessionContextManager(self) # START PROXY METHODS AsyncSession @@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]): # END PROXY METHODS AsyncSession -class async_sessionmaker: +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): """A configurable :class:`.AsyncSession` factory. The :class:`.async_sessionmaker` factory works in the same way as the @@ -1409,12 +1412,12 @@ class async_sessionmaker: """ # noqa E501 - class_: Type[AsyncSession] + class_: Type[_AS] def __init__( self, bind: Optional[_AsyncSessionBind] = None, - class_: Type[AsyncSession] = AsyncSession, + class_: Type[_AS] = AsyncSession, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -1437,7 +1440,7 @@ class async_sessionmaker: self.kw = kw self.class_ = class_ - def begin(self) -> _AsyncSessionContextManager: + def begin(self) -> _AsyncSessionContextManager[_AS]: """Produce a context manager that both provides a new :class:`_orm.AsyncSession` as well as a transaction that commits. @@ -1458,7 +1461,7 @@ class async_sessionmaker: session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> AsyncSession: + def __call__(self, **local_kw: Any) -> _AS: """Produce a new :class:`.AsyncSession` object using the configuration established in this :class:`.async_sessionmaker`. @@ -1498,16 +1501,16 @@ class async_sessionmaker: ) -class _AsyncSessionContextManager: +class _AsyncSessionContextManager(Generic[_AS]): __slots__ = ("async_session", "trans") - async_session: AsyncSession + async_session: _AS trans: AsyncSessionTransaction - def __init__(self, async_session: AsyncSession): + def __init__(self, async_session: _AS): self.async_session = async_session - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self) -> _AS: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session |