summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-05-26 14:35:03 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-31 15:17:48 -0400
commitd24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch)
tree4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /lib/sqlalchemy/orm/session.py
parent5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff)
downloadsqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes can be typed for these factories. Added appropriate typevars to `__call__()`, `__enter__()` and other methods so that a custom Session or AsyncSession subclass is carried through. Fixes: #7656 Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index d72e78c9e..788821b98 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -17,6 +17,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
+from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
@@ -1420,14 +1421,14 @@ class Session(_SessionClassMethods, EventTarget):
connection_callable: Optional[_ConnectionCallableProto] = None
- def __enter__(self) -> Session:
+ def __enter__(self: _S) -> _S:
return self
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
@contextlib.contextmanager
- def _maker_context_manager(self) -> Iterator[Session]:
+ def _maker_context_manager(self: _S) -> Iterator[_S]:
with self:
with self.begin():
yield self
@@ -4398,7 +4399,10 @@ class Session(_SessionClassMethods, EventTarget):
return util.IdentitySet(list(self._new.values()))
-class sessionmaker(_SessionClassMethods):
+_S = TypeVar("_S", bound="Session")
+
+
+class sessionmaker(_SessionClassMethods, Generic[_S]):
"""A configurable :class:`.Session` factory.
The :class:`.sessionmaker` factory generates new
@@ -4493,12 +4497,12 @@ class sessionmaker(_SessionClassMethods):
"""
- class_: Type[Session]
+ class_: Type[_S]
def __init__(
self,
bind: Optional[_SessionBind] = None,
- class_: Type[Session] = Session,
+ class_: Type[_S] = Session, # type: ignore
autoflush: bool = True,
expire_on_commit: bool = True,
info: Optional[_InfoType] = None,
@@ -4545,7 +4549,7 @@ class sessionmaker(_SessionClassMethods):
# events can be associated with it specifically.
self.class_ = type(class_.__name__, (class_,), {})
- def begin(self) -> contextlib.AbstractContextManager[Session]:
+ def begin(self) -> contextlib.AbstractContextManager[_S]:
"""Produce a context manager that both provides a new
:class:`_orm.Session` as well as a transaction that commits.
@@ -4567,7 +4571,7 @@ class sessionmaker(_SessionClassMethods):
session = self()
return session._maker_context_manager()
- def __call__(self, **local_kw: Any) -> Session:
+ def __call__(self, **local_kw: Any) -> _S:
"""Produce a new :class:`.Session` object using the configuration
established in this :class:`.sessionmaker`.