diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-26 14:35:03 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-31 15:17:48 -0400 |
commit | d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (patch) | |
tree | 4291dbaeea6b78164e492da183cff5e8e7dfd9d6 /test/ext/mypy/plain_files/sessionmakers.py | |
parent | 5531cec630ee75bfd7f5848cfe622c769be5ae48 (diff) | |
download | sqlalchemy-d24cd5e96d7f8e47c86b5013a7f989a15e2eec89.tar.gz |
establish sessionmaker and async_sessionmaker as generic
This is so that custom Session and AsyncSession classes
can be typed for these factories. Added appropriate
typevars to `__call__()`, `__enter__()` and other methods
so that a custom Session or AsyncSession subclass is carried
through.
Fixes: #7656
Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0
Diffstat (limited to 'test/ext/mypy/plain_files/sessionmakers.py')
-rw-r--r-- | test/ext/mypy/plain_files/sessionmakers.py | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py new file mode 100644 index 000000000..ce9b76638 --- /dev/null +++ b/test/ext/mypy/plain_files/sessionmakers.py @@ -0,0 +1,88 @@ +"""test #7656""" + +from sqlalchemy import create_engine +from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker + + +async_engine = create_async_engine("...") + + +class MyAsyncSession(AsyncSession): + pass + + +def async_session_factory( + engine: AsyncEngine, +) -> async_sessionmaker[MyAsyncSession]: + return async_sessionmaker(engine, class_=MyAsyncSession) + + +def async_scoped_session_factory( + engine: AsyncEngine, +) -> async_scoped_session[MyAsyncSession]: + return async_scoped_session( + async_sessionmaker(engine, class_=MyAsyncSession), + scopefunc=lambda: None, + ) + + +async def async_main() -> None: + fac = async_session_factory(async_engine) + + async with fac() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + async with fac.begin() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + scoped_fac = async_scoped_session_factory(async_engine) + + sess = scoped_fac() + + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + +engine = create_engine("...") + + +class MySession(Session): + pass + + +def session_factory( + engine: Engine, +) -> sessionmaker[MySession]: + return sessionmaker(engine, class_=MySession) + + +def scoped_session_factory(engine: Engine) -> scoped_session[MySession]: + return scoped_session(sessionmaker(engine, class_=MySession)) + + +def main() -> None: + fac = session_factory(engine) + + with fac() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + with fac.begin() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + scoped_fac = scoped_session_factory(engine) + + sess = scoped_fac() + # EXPECTED_TYPE: MySession + reveal_type(sess) |