diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-26 14:35:03 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-31 15:17:48 -0400 |
commit | d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch) | |
tree | 4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /lib/sqlalchemy/ext/asyncio/session.py | |
parent | 5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff) | |
download | sqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz |
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes
can be typed for these factories. Added appropriate
typevars to `__call__()`, `__enter__()` and other methods
so that a custom Session or AsyncSession subclass is carried
through.
Fixes: #7656
Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/session.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index eac2e5806..be3414cef 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import NoReturn @@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]): ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) @@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]): """Close all :class:`_asyncio.AsyncSession` sessions.""" await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self: _AS) -> _AS: return self async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self) -> _AsyncSessionContextManager: - # TODO: can this use asynccontextmanager ?? + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: return _AsyncSessionContextManager(self) # START PROXY METHODS AsyncSession @@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]): # END PROXY METHODS AsyncSession -class async_sessionmaker: +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): """A configurable :class:`.AsyncSession` factory. The :class:`.async_sessionmaker` factory works in the same way as the @@ -1409,12 +1412,12 @@ class async_sessionmaker: """ # noqa E501 - class_: Type[AsyncSession] + class_: Type[_AS] def __init__( self, bind: Optional[_AsyncSessionBind] = None, - class_: Type[AsyncSession] = AsyncSession, + class_: Type[_AS] = AsyncSession, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -1437,7 +1440,7 @@ class async_sessionmaker: self.kw = kw self.class_ = class_ - def begin(self) -> _AsyncSessionContextManager: + def begin(self) -> _AsyncSessionContextManager[_AS]: """Produce a context manager that both provides a new :class:`_orm.AsyncSession` as well as a transaction that commits. @@ -1458,7 +1461,7 @@ class async_sessionmaker: session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> AsyncSession: + def __call__(self, **local_kw: Any) -> _AS: """Produce a new :class:`.AsyncSession` object using the configuration established in this :class:`.async_sessionmaker`. @@ -1498,16 +1501,16 @@ class async_sessionmaker: ) -class _AsyncSessionContextManager: +class _AsyncSessionContextManager(Generic[_AS]): __slots__ = ("async_session", "trans") - async_session: AsyncSession + async_session: _AS trans: AsyncSessionTransaction - def __init__(self, async_session: AsyncSession): + def __init__(self, async_session: _AS): self.async_session = async_session - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self) -> _AS: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session |