summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/apply.py
diff options
context:
space:
mode:
authorBryan Forbes <bryan@reigndropsfall.net>2021-04-12 16:24:37 -0500
committerBryan Forbes <bryan@reigndropsfall.net>2021-04-12 16:24:37 -0500
commite2008b5541cc155aea538317805e62ff1aa9b300 (patch)
tree04608c82131e8bb3aa2ada56c5e78d4e0a8936d5 /lib/sqlalchemy/ext/mypy/apply.py
parentde7f14104d5278987fa72d6866fa39569e56077e (diff)
downloadsqlalchemy-e2008b5541cc155aea538317805e62ff1aa9b300.tar.gz
Update mypy plugin to conform to strict mode
Change-Id: I09a3df5af2f2d4ee34d8d72c3dedc4f236df8eb1
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/apply.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
index 0f4bb1fd9..366260437 100644
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -24,9 +24,12 @@ from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
+from mypy.types import ProperType
from mypy.types import TypeOfAny
+from mypy.types import UnboundType
from mypy.types import UnionType
from . import util
@@ -37,7 +40,7 @@ def _apply_mypy_mapped_attr(
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
@@ -46,7 +49,11 @@ def _apply_mypy_mapped_attr(
return
for stmt in cls.defs.body:
- if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
@@ -61,7 +68,10 @@ def _apply_mypy_mapped_attr(
)
return
- left_hand_explicit_type = stmt.type
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
@@ -74,7 +84,7 @@ def _re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
@@ -90,7 +100,9 @@ def _re_apply_declarative_assignments(
# will change).
if (
isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
):
typ = mapped_attr_lookup[stmt.lvalues[0].name]
left_node = stmt.lvalues[0].node
@@ -102,8 +114,8 @@ def _apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
- left_hand_explicit_type: Optional[Union[Instance, UnionType]],
- python_type_for_type: Union[Instance, UnionType],
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
) -> None:
"""Apply the Mapped[<type>] annotation and right hand object to a
declarative assignment statement.
@@ -124,6 +136,7 @@ def _apply_type_to_mapped_statement(
"""
left_node = lvalue.node
+ assert isinstance(left_node, Var)
if left_hand_explicit_type is not None:
left_node.type = api.named_type(
@@ -131,7 +144,10 @@ def _apply_type_to_mapped_statement(
)
else:
lvalue.is_inferred_def = False
- left_node.type = api.named_type("__sa_Mapped", [python_type_for_type])
+ left_node.type = api.named_type(
+ "__sa_Mapped",
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
# so to have it skip the right side totally, we can do this:
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
@@ -146,7 +162,7 @@ def _apply_type_to_mapped_statement(
# the original right-hand side is maintained so it gets type checked
# internally
column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
orig_call_expr = stmt.rvalue
stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
@@ -199,11 +215,11 @@ def _apply_placeholder_attr_to_class(
cls: ClassDef,
qualified_name: str,
attrname: str,
-):
+) -> None:
sym = api.lookup_fully_qualified_or_none(qualified_name)
if sym:
assert isinstance(sym.node, TypeInfo)
- type_: Union[Instance, AnyType] = Instance(sym.node, [])
+ type_: ProperType = Instance(sym.node, [])
else:
type_ = AnyType(TypeOfAny.special_form)
var = Var(attrname)