summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/apply.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/apply.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
index 0f4bb1fd9..366260437 100644
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -24,9 +24,12 @@ from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
+from mypy.types import ProperType
from mypy.types import TypeOfAny
+from mypy.types import UnboundType
from mypy.types import UnionType
from . import util
@@ -37,7 +40,7 @@ def _apply_mypy_mapped_attr(
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
@@ -46,7 +49,11 @@ def _apply_mypy_mapped_attr(
return
for stmt in cls.defs.body:
- if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
@@ -61,7 +68,10 @@ def _apply_mypy_mapped_attr(
)
return
- left_hand_explicit_type = stmt.type
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
@@ -74,7 +84,7 @@ def _re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
@@ -90,7 +100,9 @@ def _re_apply_declarative_assignments(
# will change).
if (
isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
):
typ = mapped_attr_lookup[stmt.lvalues[0].name]
left_node = stmt.lvalues[0].node
@@ -102,8 +114,8 @@ def _apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
- left_hand_explicit_type: Optional[Union[Instance, UnionType]],
- python_type_for_type: Union[Instance, UnionType],
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
) -> None:
"""Apply the Mapped[<type>] annotation and right hand object to a
declarative assignment statement.
@@ -124,6 +136,7 @@ def _apply_type_to_mapped_statement(
"""
left_node = lvalue.node
+ assert isinstance(left_node, Var)
if left_hand_explicit_type is not None:
left_node.type = api.named_type(
@@ -131,7 +144,10 @@ def _apply_type_to_mapped_statement(
)
else:
lvalue.is_inferred_def = False
- left_node.type = api.named_type("__sa_Mapped", [python_type_for_type])
+ left_node.type = api.named_type(
+ "__sa_Mapped",
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
# so to have it skip the right side totally, we can do this:
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
@@ -146,7 +162,7 @@ def _apply_type_to_mapped_statement(
# the original right-hand side is maintained so it gets type checked
# internally
column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
orig_call_expr = stmt.rvalue
stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
@@ -199,11 +215,11 @@ def _apply_placeholder_attr_to_class(
cls: ClassDef,
qualified_name: str,
attrname: str,
-):
+) -> None:
sym = api.lookup_fully_qualified_or_none(qualified_name)
if sym:
assert isinstance(sym.node, TypeInfo)
- type_: Union[Instance, AnyType] = Instance(sym.node, [])
+ type_: ProperType = Instance(sym.node, [])
else:
type_ = AnyType(TypeOfAny.special_form)
var = Var(attrname)