diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-26 14:35:03 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-31 15:17:48 -0400 |
commit | d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch) | |
tree | 4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /lib/sqlalchemy/orm/session.py | |
parent | 5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff) | |
download | sqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz |
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes
can be typed for these factories. Added appropriate
typevars to `__call__()`, `__enter__()` and other methods
so that a custom Session or AsyncSession subclass is carried
through.
Fixes: #7656
Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d72e78c9e..788821b98 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -17,6 +17,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -1420,14 +1421,14 @@ class Session(_SessionClassMethods, EventTarget): connection_callable: Optional[_ConnectionCallableProto] = None - def __enter__(self) -> Session: + def __enter__(self: _S) -> _S: return self def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() @contextlib.contextmanager - def _maker_context_manager(self) -> Iterator[Session]: + def _maker_context_manager(self: _S) -> Iterator[_S]: with self: with self.begin(): yield self @@ -4398,7 +4399,10 @@ class Session(_SessionClassMethods, EventTarget): return util.IdentitySet(list(self._new.values())) -class sessionmaker(_SessionClassMethods): +_S = TypeVar("_S", bound="Session") + + +class sessionmaker(_SessionClassMethods, Generic[_S]): """A configurable :class:`.Session` factory. The :class:`.sessionmaker` factory generates new @@ -4493,12 +4497,12 @@ class sessionmaker(_SessionClassMethods): """ - class_: Type[Session] + class_: Type[_S] def __init__( self, bind: Optional[_SessionBind] = None, - class_: Type[Session] = Session, + class_: Type[_S] = Session, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -4545,7 +4549,7 @@ class sessionmaker(_SessionClassMethods): # events can be associated with it specifically. self.class_ = type(class_.__name__, (class_,), {}) - def begin(self) -> contextlib.AbstractContextManager[Session]: + def begin(self) -> contextlib.AbstractContextManager[_S]: """Produce a context manager that both provides a new :class:`_orm.Session` as well as a transaction that commits. @@ -4567,7 +4571,7 @@ class sessionmaker(_SessionClassMethods): session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> Session: + def __call__(self, **local_kw: Any) -> _S: """Produce a new :class:`.Session` object using the configuration established in this :class:`.sessionmaker`. |