summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9340.rst6
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py2
-rw-r--r--lib/sqlalchemy/orm/util.py2
-rw-r--r--test/ext/mypy/plain_files/issue_9340.py63
4 files changed, 71 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_20/9340.rst b/doc/build/changelog/unreleased_20/9340.rst
new file mode 100644
index 000000000..28cef6f64
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9340.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, typing
+ :tickets: 9340
+
+ Fixed typing issue where :func:`_orm.with_polymorphic` would not
+ record the class type correctly.
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 3bd1db79d..64e7937f1 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -2208,7 +2208,7 @@ def aliased(
def with_polymorphic(
- base: Union[_O, Mapper[_O]],
+ base: Union[Type[_O], Mapper[_O]],
classes: Union[Literal["*"], Iterable[Type[Any]]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index ad9ce2013..1ef0d7159 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -994,7 +994,7 @@ class AliasedInsp(
@classmethod
def _with_polymorphic_factory(
cls,
- base: Union[_O, Mapper[_O]],
+ base: Union[Type[_O], Mapper[_O]],
classes: Union[Literal["*"], Iterable[_EntityType[Any]]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
diff --git a/test/ext/mypy/plain_files/issue_9340.py b/test/ext/mypy/plain_files/issue_9340.py
new file mode 100644
index 000000000..72dc72df1
--- /dev/null
+++ b/test/ext/mypy/plain_files/issue_9340.py
@@ -0,0 +1,63 @@
+from typing import Sequence
+from typing import TYPE_CHECKING
+
+from sqlalchemy import create_engine
+from sqlalchemy import select
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import with_polymorphic
+
+
+class Base(DeclarativeBase):
+ ...
+
+
+class Message(Base):
+ __tablename__ = "message"
+ __mapper_args__ = {
+ "polymorphic_on": "message_type",
+ "polymorphic_identity": __tablename__,
+ }
+ id: Mapped[int] = mapped_column(primary_key=True)
+ text: Mapped[str]
+ message_type: Mapped[str]
+
+
+class UserComment(Message):
+ __mapper_args__ = {
+ "polymorphic_identity": "user_comment",
+ }
+ username: Mapped[str] = mapped_column(nullable=True)
+
+
+engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/")
+
+
+def get_messages() -> Sequence[Message]:
+ with Session(engine) as session:
+ message_query = select(Message)
+
+ if TYPE_CHECKING:
+ # EXPECTED_TYPE: Select[Tuple[Message]]
+ reveal_type(message_query)
+
+ return session.scalars(message_query).all()
+
+
+def get_poly_messages() -> Sequence[Message]:
+ with Session(engine) as session:
+ PolymorphicMessage = with_polymorphic(Message, (UserComment,))
+
+ if TYPE_CHECKING:
+ # EXPECTED_TYPE: AliasedClass[Message]
+ reveal_type(PolymorphicMessage)
+
+ poly_query = select(PolymorphicMessage)
+
+ if TYPE_CHECKING:
+ # EXPECTED_TYPE: Select[Tuple[Message]]
+ reveal_type(poly_query)
+
+ return session.scalars(poly_query).all()