diff options
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/session.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index eac2e5806..be3414cef 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import NoReturn @@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]): ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) @@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]): """Close all :class:`_asyncio.AsyncSession` sessions.""" await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self: _AS) -> _AS: return self async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self) -> _AsyncSessionContextManager: - # TODO: can this use asynccontextmanager ?? + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: return _AsyncSessionContextManager(self) # START PROXY METHODS AsyncSession @@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]): # END PROXY METHODS AsyncSession -class async_sessionmaker: +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): """A configurable :class:`.AsyncSession` factory. The :class:`.async_sessionmaker` factory works in the same way as the @@ -1409,12 +1412,12 @@ class async_sessionmaker: """ # noqa E501 - class_: Type[AsyncSession] + class_: Type[_AS] def __init__( self, bind: Optional[_AsyncSessionBind] = None, - class_: Type[AsyncSession] = AsyncSession, + class_: Type[_AS] = AsyncSession, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -1437,7 +1440,7 @@ class async_sessionmaker: self.kw = kw self.class_ = class_ - def begin(self) -> _AsyncSessionContextManager: + def begin(self) -> _AsyncSessionContextManager[_AS]: """Produce a context manager that both provides a new :class:`_orm.AsyncSession` as well as a transaction that commits. @@ -1458,7 +1461,7 @@ class async_sessionmaker: session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> AsyncSession: + def __call__(self, **local_kw: Any) -> _AS: """Produce a new :class:`.AsyncSession` object using the configuration established in this :class:`.async_sessionmaker`. @@ -1498,16 +1501,16 @@ class async_sessionmaker: ) -class _AsyncSessionContextManager: +class _AsyncSessionContextManager(Generic[_AS]): __slots__ = ("async_session", "trans") - async_session: AsyncSession + async_session: _AS trans: AsyncSessionTransaction - def __init__(self, async_session: AsyncSession): + def __init__(self, async_session: _AS): self.async_session = async_session - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self) -> _AS: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session |