summaryrefslogtreecommitdiff
path: root/test/ext/mypy/plain_files/sessionmakers.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-05-26 14:35:03 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-31 15:17:48 -0400
commitd24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch)
tree4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /test/ext/mypy/plain_files/sessionmakers.py
parent5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff)
downloadsqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes can be typed for these factories. Added appropriate typevars to `__call__()`, `__enter__()` and other methods so that a custom Session or AsyncSession subclass is carried through. Fixes: #7656 Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'test/ext/mypy/plain_files/sessionmakers.py')
-rw-r--r--test/ext/mypy/plain_files/sessionmakers.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py
new file mode 100644
index 000000000..ce9b76638
--- /dev/null
+++ b/test/ext/mypy/plain_files/sessionmakers.py
@@ -0,0 +1,88 @@
+"""test #7656"""
+
+from sqlalchemy import create_engine
+from sqlalchemy import Engine
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import scoped_session
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
+
+
+async_engine = create_async_engine("...")
+
+
+class MyAsyncSession(AsyncSession):
+ pass
+
+
+def async_session_factory(
+ engine: AsyncEngine,
+) -> async_sessionmaker[MyAsyncSession]:
+ return async_sessionmaker(engine, class_=MyAsyncSession)
+
+
+def async_scoped_session_factory(
+ engine: AsyncEngine,
+) -> async_scoped_session[MyAsyncSession]:
+ return async_scoped_session(
+ async_sessionmaker(engine, class_=MyAsyncSession),
+ scopefunc=lambda: None,
+ )
+
+
+async def async_main() -> None:
+ fac = async_session_factory(async_engine)
+
+ async with fac() as sess:
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+ async with fac.begin() as sess:
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+ scoped_fac = async_scoped_session_factory(async_engine)
+
+ sess = scoped_fac()
+
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+
+engine = create_engine("...")
+
+
+class MySession(Session):
+ pass
+
+
+def session_factory(
+ engine: Engine,
+) -> sessionmaker[MySession]:
+ return sessionmaker(engine, class_=MySession)
+
+
+def scoped_session_factory(engine: Engine) -> scoped_session[MySession]:
+ return scoped_session(sessionmaker(engine, class_=MySession))
+
+
+def main() -> None:
+ fac = session_factory(engine)
+
+ with fac() as sess:
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)
+
+ with fac.begin() as sess:
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)
+
+ scoped_fac = scoped_session_factory(engine)
+
+ sess = scoped_fac()
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)