summaryrefslogtreecommitdiff
path: root/test/ext/mypy/plain_files/issue_9340.py
blob: 72dc72df1ecd668fefc859be48020420aa5f4b92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Sequence
from typing import TYPE_CHECKING

from sqlalchemy import create_engine
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session
from sqlalchemy.orm import with_polymorphic


class Base(DeclarativeBase):
    """Declarative base class shared by every mapped class in this module."""


class Message(Base):
    """Root of a polymorphic hierarchy mapped to the ``message`` table.

    ``message_type`` is the discriminator column. A row whose discriminator
    equals the table name (``"message"``) loads as a plain ``Message``.
    """

    __tablename__ = "message"

    id: Mapped[int] = mapped_column(primary_key=True)
    text: Mapped[str]
    message_type: Mapped[str]

    # Polymorphic loading is keyed on ``message_type``. This class's own
    # identity reuses the table name, so the two cannot drift apart.
    __mapper_args__ = {
        "polymorphic_on": "message_type",
        "polymorphic_identity": __tablename__,
    }


class UserComment(Message):
    """``Message`` subtype whose rows carry ``message_type == "user_comment"``.

    No ``__tablename__`` is declared, so this class shares its parent's
    ``message`` table (single-table inheritance).
    """

    # Nullable at the database level because non-comment rows in the shared
    # table have no username.
    username: Mapped[str] = mapped_column(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "user_comment",
    }


# Module-level engine shared by the query functions below. The URL, including
# the "scott:tiger" credentials, is hard-coded. Presumably this fixture is only
# type-checked and never connects to a real database (TODO confirm).
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/")


def get_messages() -> Sequence[Message]:
    """Return every ``Message`` row, loaded with a plain ``select(Message)``.

    This is also a typing fixture for the mypy tests. The ``reveal_type``
    call checks that ``select()`` over a mapped class is inferred as
    ``Select[Tuple[Message]]``.

    :return: the scalar ``Message`` results of the query, materialised with
        ``.all()`` before the session is closed by the ``with`` block.
    """
    with Session(engine) as session:
        message_query = select(Message)

        # reveal_type() is understood only by the type checker and is not
        # imported in this module, so it must stay behind TYPE_CHECKING
        # (which is False at runtime).
        if TYPE_CHECKING:
            # EXPECTED_TYPE: Select[Tuple[Message]]
            reveal_type(message_query)

        return session.scalars(message_query).all()


def get_poly_messages() -> Sequence[Message]:
    """Return ``Message`` rows loaded through a ``with_polymorphic`` alias.

    The alias is built over ``Message`` and also names ``UserComment``. The
    ``reveal_type`` checks pin down what mypy infers for the alias
    (``AliasedClass[Message]``) and for ``select()`` over it
    (``Select[Tuple[Message]]``). Judging by the file name, this is the
    regression case for issue 9340 (TODO confirm against the tracker).

    :return: the scalar results of the polymorphic query, materialised with
        ``.all()`` before the session is closed by the ``with`` block.
    """
    with Session(engine) as session:
        # Aliased entity covering Message plus the UserComment subclass.
        # It is used like a mapped class, hence the class-style name.
        PolymorphicMessage = with_polymorphic(Message, (UserComment,))

        # reveal_type() is only meaningful to the type checker.
        if TYPE_CHECKING:
            # EXPECTED_TYPE: AliasedClass[Message]
            reveal_type(PolymorphicMessage)

        poly_query = select(PolymorphicMessage)

        # Selecting the alias should still be typed as a Message row.
        if TYPE_CHECKING:
            # EXPECTED_TYPE: Select[Tuple[Message]]
            reveal_type(poly_query)

        return session.scalars(poly_query).all()