diff options
-rw-r--r-- | doc/build/changelog/unreleased_20/9340.rst | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/_orm_constructors.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 2 | ||||
-rw-r--r-- | test/ext/mypy/plain_files/issue_9340.py | 63 |
4 files changed, 71 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_20/9340.rst b/doc/build/changelog/unreleased_20/9340.rst new file mode 100644 index 000000000..28cef6f64 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9340.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, typing + :tickets: 9340 + + Fixed typing issue where :func:`_orm.with_polymorphic` would not + record the class type correctly. diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 3bd1db79d..64e7937f1 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -2208,7 +2208,7 @@ def aliased( def with_polymorphic( - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[Type[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad9ce2013..1ef0d7159 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -994,7 +994,7 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[_EntityType[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, diff --git a/test/ext/mypy/plain_files/issue_9340.py b/test/ext/mypy/plain_files/issue_9340.py new file mode 100644 index 000000000..72dc72df1 --- /dev/null +++ b/test/ext/mypy/plain_files/issue_9340.py @@ -0,0 +1,63 @@ +from typing import Sequence +from typing import TYPE_CHECKING + +from sqlalchemy import create_engine +from sqlalchemy import select +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session +from sqlalchemy.orm import with_polymorphic + + +class Base(DeclarativeBase): + ... + + +class Message(Base): + __tablename__ = "message" + __mapper_args__ = { + "polymorphic_on": "message_type", + "polymorphic_identity": __tablename__, + } + id: Mapped[int] = mapped_column(primary_key=True) + text: Mapped[str] + message_type: Mapped[str] + + +class UserComment(Message): + __mapper_args__ = { + "polymorphic_identity": "user_comment", + } + username: Mapped[str] = mapped_column(nullable=True) + + +engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/") + + +def get_messages() -> Sequence[Message]: + with Session(engine) as session: + message_query = select(Message) + + if TYPE_CHECKING: + # EXPECTED_TYPE: Select[Tuple[Message]] + reveal_type(message_query) + + return session.scalars(message_query).all() + + +def get_poly_messages() -> Sequence[Message]: + with Session(engine) as session: + PolymorphicMessage = with_polymorphic(Message, (UserComment,)) + + if TYPE_CHECKING: + # EXPECTED_TYPE: AliasedClass[Message] + reveal_type(PolymorphicMessage) + + poly_query = select(PolymorphicMessage) + + if TYPE_CHECKING: + # EXPECTED_TYPE: Select[Tuple[Message]] + reveal_type(poly_query) + + return session.scalars(poly_query).all() |