summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy
diff options
context:
space:
mode:
authorBryan Forbes <bryan@reigndropsfall.net>2021-04-12 16:24:37 -0500
committerBryan Forbes <bryan@reigndropsfall.net>2021-04-12 16:24:37 -0500
commite2008b5541cc155aea538317805e62ff1aa9b300 (patch)
tree04608c82131e8bb3aa2ada56c5e78d4e0a8936d5 /lib/sqlalchemy/ext/mypy
parentde7f14104d5278987fa72d6866fa39569e56077e (diff)
downloadsqlalchemy-e2008b5541cc155aea538317805e62ff1aa9b300.tar.gz
Update mypy plugin to conform to strict mode
Change-Id: I09a3df5af2f2d4ee34d8d72c3dedc4f236df8eb1
Diffstat (limited to 'lib/sqlalchemy/ext/mypy')
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py36
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py122
-rw-r--r--lib/sqlalchemy/ext/mypy/infer.py155
-rw-r--r--lib/sqlalchemy/ext/mypy/names.py122
-rw-r--r--lib/sqlalchemy/ext/mypy/plugin.py47
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py115
6 files changed, 382 insertions, 215 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
index 0f4bb1fd9..366260437 100644
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -24,9 +24,12 @@ from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
+from mypy.types import ProperType
from mypy.types import TypeOfAny
+from mypy.types import UnboundType
from mypy.types import UnionType
from . import util
@@ -37,7 +40,7 @@ def _apply_mypy_mapped_attr(
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
@@ -46,7 +49,11 @@ def _apply_mypy_mapped_attr(
return
for stmt in cls.defs.body:
- if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
@@ -61,7 +68,10 @@ def _apply_mypy_mapped_attr(
)
return
- left_hand_explicit_type = stmt.type
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
@@ -74,7 +84,7 @@ def _re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
@@ -90,7 +100,9 @@ def _re_apply_declarative_assignments(
# will change).
if (
isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
):
typ = mapped_attr_lookup[stmt.lvalues[0].name]
left_node = stmt.lvalues[0].node
@@ -102,8 +114,8 @@ def _apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
- left_hand_explicit_type: Optional[Union[Instance, UnionType]],
- python_type_for_type: Union[Instance, UnionType],
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
) -> None:
"""Apply the Mapped[<type>] annotation and right hand object to a
declarative assignment statement.
@@ -124,6 +136,7 @@ def _apply_type_to_mapped_statement(
"""
left_node = lvalue.node
+ assert isinstance(left_node, Var)
if left_hand_explicit_type is not None:
left_node.type = api.named_type(
@@ -131,7 +144,10 @@ def _apply_type_to_mapped_statement(
)
else:
lvalue.is_inferred_def = False
- left_node.type = api.named_type("__sa_Mapped", [python_type_for_type])
+ left_node.type = api.named_type(
+ "__sa_Mapped",
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
# so to have it skip the right side totally, we can do this:
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
@@ -146,7 +162,7 @@ def _apply_type_to_mapped_statement(
# the original right-hand side is maintained so it gets type checked
# internally
column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
orig_call_expr = stmt.rvalue
stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
@@ -199,11 +215,11 @@ def _apply_placeholder_attr_to_class(
cls: ClassDef,
qualified_name: str,
attrname: str,
-):
+) -> None:
sym = api.lookup_fully_qualified_or_none(qualified_name)
if sym:
assert isinstance(sym.node, TypeInfo)
- type_: Union[Instance, AnyType] = Instance(sym.node, [])
+ type_: ProperType = Instance(sym.node, [])
else:
type_ = AnyType(TypeOfAny.special_form)
var = Var(attrname)
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
index 40f1f0c0f..8fac36342 100644
--- a/lib/sqlalchemy/ext/mypy/decl_class.py
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -6,7 +6,7 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from typing import Optional
-from typing import Type
+from typing import Union
from mypy import nodes
from mypy.nodes import AssignmentStmt
@@ -14,18 +14,24 @@ from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
@@ -37,7 +43,9 @@ from . import util
def _scan_declarative_assignments_and_apply_types(
- cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ is_mixin_scan: bool = False,
) -> Optional[util.DeclClassApplied]:
info = util._info_for_cls(cls, api)
@@ -94,16 +102,17 @@ def _scan_symbol_table_entry(
name: str,
value: SymbolTableNode,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
"""
- if not isinstance(value.type, Instance):
+ value_type = get_proper_type(value.type)
+ if not isinstance(value_type, Instance):
return
left_hand_explicit_type = None
- type_id = names._type_id_for_named_node(value.type.type)
+ type_id = names._type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
@@ -118,22 +127,24 @@ def _scan_symbol_table_entry(
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}:
- if value.type.args:
- left_hand_explicit_type = value.type.args[0]
+ if value_type.args:
+ left_hand_explicit_type = get_proper_type(value_type.args[0])
else:
err = True
elif type_id is names.COLUMN:
- if not value.type.args:
+ if not value_type.args:
err = True
else:
- typeengine_arg = value.type.args[0]
+ typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+ value_type.args[0]
+ )
if isinstance(typeengine_arg, Instance):
typeengine_arg = typeengine_arg.type
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
- if sym is not None:
- if names._mro_has_id(sym.node.mro, names.TYPEENGINE):
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names._has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
@@ -148,7 +159,7 @@ def _scan_symbol_table_entry(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
- value.type,
+ value_type,
)
if err:
@@ -158,7 +169,7 @@ def _scan_symbol_table_entry(
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
- util.fail(api, msg.format(name, cls.name))
+ util.fail(api, msg.format(name, cls.name), cls)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
@@ -171,7 +182,7 @@ def _scan_declarative_decorator_stmt(
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
@@ -195,16 +206,19 @@ def _scan_declarative_decorator_stmt(
"""
for dec in stmt.decorators:
- if names._type_id_for_named_node(dec) is names.DECLARED_ATTR:
+ if (
+ isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+ and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+ ):
break
else:
return
dec_index = cls.defs.body.index(stmt)
- left_hand_explicit_type = None
+ left_hand_explicit_type: Optional[ProperType] = None
- if stmt.func.type is not None:
+ if isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
type_id = names._type_id_for_unbound_type(func_type, cls, api)
@@ -225,30 +239,28 @@ def _scan_declarative_decorator_stmt(
}
and func_type.args
):
- left_hand_explicit_type = func_type.args[0]
+ left_hand_explicit_type = get_proper_type(func_type.args[0])
elif type_id is names.COLUMN and func_type.args:
typeengine_arg = func_type.args[0]
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
- if sym is not None and names._mro_has_id(
- sym.node.mro, names.TYPEENGINE
- ):
-
- left_hand_explicit_type = UnionType(
- [
- infer._extract_python_type_from_typeengine(
- api, sym.node, []
- ),
- NoneType(),
- ]
- )
- else:
- util.fail(
- api,
- "Column type should be a TypeEngine "
- "subclass not '{}'".format(sym.node.fullname),
- func_type,
- )
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ left_hand_explicit_type = UnionType(
+ [
+ infer._extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ func_type,
+ )
if left_hand_explicit_type is None:
# no type on the decorated function. our option here is to
@@ -274,8 +286,8 @@ def _scan_declarative_decorator_stmt(
# of converting it to the regular Instance/TypeInfo/UnionType structures
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
- left_hand_explicit_type = util._unbound_to_instance(
- api, left_hand_explicit_type
+ left_hand_explicit_type = get_proper_type(
+ util._unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
@@ -315,7 +327,7 @@ def _scan_declarative_assignment_stmt(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
@@ -339,7 +351,7 @@ def _scan_declarative_assignment_stmt(
assert isinstance(node, Var)
if node.name == "__abstract__":
- if stmt.rvalue.fullname == "builtins.True":
+ if api.parse_bool(stmt.rvalue) is True:
cls_metadata.is_mapped = False
return
elif node.name == "__tablename__":
@@ -354,7 +366,8 @@ def _scan_declarative_assignment_stmt(
if isinstance(item, (NameExpr, StrExpr)):
apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
- left_hand_mapped_type: Type = None
+ left_hand_mapped_type: Optional[Type] = None
+ left_hand_explicit_type: Optional[ProperType] = None
if node.is_inferred or node.type is None:
if isinstance(stmt.type, UnboundType):
@@ -370,32 +383,33 @@ def _scan_declarative_assignment_stmt(
mapped_sym = api.lookup_qualified("Mapped", cls)
if (
mapped_sym is not None
+ and mapped_sym.node is not None
and names._type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
- left_hand_explicit_type = stmt.type.args[0]
+ left_hand_explicit_type = get_proper_type(
+ stmt.type.args[0]
+ )
left_hand_mapped_type = stmt.type
# TODO: do we need to convert from unbound for this case?
# left_hand_explicit_type = util._unbound_to_instance(
# api, left_hand_explicit_type
# )
-
- else:
- left_hand_explicit_type = None
else:
+ node_type = get_proper_type(node.type)
if (
- isinstance(node.type, Instance)
- and names._type_id_for_named_node(node.type.type) is names.MAPPED
+ isinstance(node_type, Instance)
+ and names._type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
- left_hand_explicit_type = node.type.args[0]
- left_hand_mapped_type = node.type
+ left_hand_explicit_type = get_proper_type(node_type.args[0])
+ left_hand_mapped_type = node_type
else:
# print(node.type)
# <python type>
- left_hand_explicit_type = node.type
+ left_hand_explicit_type = node_type
left_hand_mapped_type = None
if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
@@ -440,10 +454,10 @@ def _scan_declarative_assignment_stmt(
else:
return
- cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
-
assert python_type_for_type is not None
+ cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+
apply._apply_type_to_mapped_statement(
api,
stmt,
@@ -485,6 +499,6 @@ def _scan_for_mapped_bases(
)
)
- if base_decl_class_applied not in (None, False):
+ if base_decl_class_applied is not None:
cls_metadata.mapped_mro.append(base)
baseclasses.extend(base.type.bases)
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
index f1bda7865..7915c3ae2 100644
--- a/lib/sqlalchemy/ext/mypy/infer.py
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -6,23 +6,26 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from typing import Optional
-from typing import Union
+from typing import Sequence
-from mypy import nodes
-from mypy import types
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.subtypes import is_subtype
from mypy.types import AnyType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnionType
@@ -34,8 +37,8 @@ def _infer_type_from_relationship(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a relationship.
E.g.::
@@ -62,7 +65,7 @@ def _infer_type_from_relationship(
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
- python_type_for_type = None
+ python_type_for_type: Optional[ProperType] = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
@@ -86,7 +89,7 @@ def _infer_type_from_relationship(
# isinstance(target_cls_arg, StrExpr)
uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
- collection_cls_arg = util._get_callexpr_kwarg(
+ collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg(
stmt.rvalue, "collection_class"
)
type_is_a_collection = False
@@ -98,7 +101,7 @@ def _infer_type_from_relationship(
if (
uselist_arg is not None
- and uselist_arg.fullname == "builtins.True"
+ and api.parse_bool(uselist_arg) is True
and collection_cls_arg is None
):
type_is_a_collection = True
@@ -107,7 +110,7 @@ def _infer_type_from_relationship(
"__builtins__.list", [python_type_for_type]
)
elif (
- uselist_arg is None or uselist_arg.fullname == "builtins.True"
+ uselist_arg is None or api.parse_bool(uselist_arg) is True
) and collection_cls_arg is not None:
type_is_a_collection = True
if isinstance(collection_cls_arg, CallExpr):
@@ -130,7 +133,7 @@ def _infer_type_from_relationship(
stmt.rvalue,
)
python_type_for_type = None
- elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
+ elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
if collection_cls_arg is not None:
util.fail(
api,
@@ -159,13 +162,19 @@ def _infer_type_from_relationship(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
- return _infer_type_from_left_and_inferred_right(
- api,
- node,
- left_hand_explicit_type,
- python_type_for_type,
- type_is_a_collection=type_is_a_collection,
- )
+ if type_is_a_collection:
+ assert isinstance(left_hand_explicit_type, Instance)
+ assert isinstance(python_type_for_type, Instance)
+ return _infer_collection_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
else:
return python_type_for_type
@@ -174,8 +183,8 @@ def _infer_type_from_decl_composite_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a CompositeProperty."""
assert isinstance(stmt.rvalue, CallExpr)
@@ -206,8 +215,8 @@ def _infer_type_from_decl_column_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a ColumnProperty.
This includes mappings against ``column_property()`` as well as the
@@ -219,28 +228,26 @@ def _infer_type_from_decl_column_property(
if isinstance(first_prop_arg, CallExpr):
type_id = names._type_id_for_callee(first_prop_arg.callee)
- else:
- type_id = None
- # look for column_property() / deferred() etc with Column as first
- # argument
- if type_id is names.COLUMN:
- return _infer_type_from_decl_column(
- api, stmt, node, left_hand_explicit_type, first_prop_arg
- )
- else:
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type, first_prop_arg
+ )
+
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
def _infer_type_from_decl_column(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
+ left_hand_explicit_type: Optional[ProperType],
right_hand_expression: CallExpr,
-) -> Union[Instance, UnionType, None]:
+) -> Optional[ProperType]:
"""Infer the type of mapping from a Column.
E.g.::
@@ -277,12 +284,13 @@ def _infer_type_from_decl_column(
callee = None
for column_arg in right_hand_expression.args[0:2]:
- if isinstance(column_arg, nodes.CallExpr):
- # x = Column(String(50))
- callee = column_arg.callee
- type_args = column_arg.args
- break
- elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
+ if isinstance(column_arg, CallExpr):
+ if isinstance(column_arg.callee, RefExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args: Sequence[Expression] = column_arg.args
+ break
+ elif isinstance(column_arg, (NameExpr, MemberExpr)):
if isinstance(column_arg.node, TypeInfo):
# x = Column(String)
callee = column_arg
@@ -314,10 +322,7 @@ def _infer_type_from_decl_column(
)
else:
- python_type_for_type = UnionType(
- [python_type_for_type, NoneType()]
- )
- return python_type_for_type
+ return UnionType([python_type_for_type, NoneType()])
else:
# it's not TypeEngine, it's typically implicitly typed
# like ForeignKey. we can't infer from the right side.
@@ -329,10 +334,11 @@ def _infer_type_from_decl_column(
def _infer_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
- python_type_for_type: Union[Instance, UnionType],
- type_is_a_collection: bool = False,
-) -> Optional[Union[Instance, UnionType]]:
+ left_hand_explicit_type: ProperType,
+ python_type_for_type: ProperType,
+ orig_left_hand_type: Optional[ProperType] = None,
+ orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
"""Validate type when a left hand annotation is present and we also
could infer the right hand side::
@@ -340,12 +346,10 @@ def _infer_type_from_left_and_inferred_right(
"""
- orig_left_hand_type = left_hand_explicit_type
- orig_python_type_for_type = python_type_for_type
-
- if type_is_a_collection and left_hand_explicit_type.args:
- left_hand_explicit_type = left_hand_explicit_type.args[0]
- python_type_for_type = python_type_for_type.args[0]
+ if orig_left_hand_type is None:
+ orig_left_hand_type = left_hand_explicit_type
+ if orig_python_type_for_type is None:
+ orig_python_type_for_type = python_type_for_type
if not is_subtype(left_hand_explicit_type, python_type_for_type):
effective_type = api.named_type(
@@ -369,11 +373,40 @@ def _infer_type_from_left_and_inferred_right(
return orig_left_hand_type
+def _infer_collection_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Instance,
+ python_type_for_type: Instance,
+) -> Optional[ProperType]:
+ orig_left_hand_type = left_hand_explicit_type
+ orig_python_type_for_type = python_type_for_type
+
+ if left_hand_explicit_type.args:
+ left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+ python_type_arg = get_proper_type(python_type_for_type.args[0])
+ else:
+ left_hand_arg = left_hand_explicit_type
+ python_type_arg = python_type_for_type
+
+ assert isinstance(left_hand_arg, (Instance, UnionType))
+ assert isinstance(python_type_arg, (Instance, UnionType))
+
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_arg,
+ python_type_arg,
+ orig_left_hand_type=orig_left_hand_type,
+ orig_python_type_for_type=orig_python_type_for_type,
+ )
+
+
def _infer_type_from_left_hand_type_only(
api: SemanticAnalyzerPluginInterface,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Optional[Union[Instance, UnionType]]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Determine the type based on explicit annotation only.
if no annotation were present, note that we need one there to know
@@ -397,8 +430,10 @@ def _infer_type_from_left_hand_type_only(
def _extract_python_type_from_typeengine(
- api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
-) -> Instance:
+ api: SemanticAnalyzerPluginInterface,
+ node: TypeInfo,
+ type_args: Sequence[Expression],
+) -> ProperType:
if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
first_arg = type_args[0]
if isinstance(first_arg, NameExpr) and isinstance(
@@ -426,4 +461,4 @@ def _extract_python_type_from_typeengine(
Instance(node, []),
type_engine_sym.node,
)
- return type_engine.args[-1]
+ return get_proper_type(type_engine.args[-1])
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
index 174a8f422..6ee600cd7 100644
--- a/lib/sqlalchemy/ext/mypy/names.py
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -5,40 +5,48 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from typing import Dict
from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Union
from mypy.nodes import ClassDef
from mypy.nodes import Expression
from mypy.nodes import FuncDef
-from mypy.nodes import RefExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
-from mypy.nodes import Union
from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
from mypy.types import UnboundType
from ... import util
-COLUMN = util.symbol("COLUMN")
-RELATIONSHIP = util.symbol("RELATIONSHIP")
-REGISTRY = util.symbol("REGISTRY")
-COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
-TYPEENGINE = util.symbol("TYPEENGNE")
-MAPPED = util.symbol("MAPPED")
-DECLARATIVE_BASE = util.symbol("DECLARATIVE_BASE")
-DECLARATIVE_META = util.symbol("DECLARATIVE_META")
-MAPPED_DECORATOR = util.symbol("MAPPED_DECORATOR")
-COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
-SYNONYM_PROPERTY = util.symbol("SYNONYM_PROPERTY")
-COMPOSITE_PROPERTY = util.symbol("COMPOSITE_PROPERTY")
-DECLARED_ATTR = util.symbol("DECLARED_ATTR")
-MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY")
-AS_DECLARATIVE = util.symbol("AS_DECLARATIVE")
-AS_DECLARATIVE_BASE = util.symbol("AS_DECLARATIVE_BASE")
-DECLARATIVE_MIXIN = util.symbol("DECLARATIVE_MIXIN")
-
-_lookup = {
+COLUMN: int = util.symbol("COLUMN") # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore
+REGISTRY: int = util.symbol("REGISTRY") # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore
+MAPPED: int = util.symbol("MAPPED") # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
COLUMN,
{
@@ -145,7 +153,21 @@ _lookup = {
}
-def _mro_has_id(mro: List[TypeInfo], type_id: int):
+def _has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+ for mr in info.mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
for mr in mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
@@ -153,65 +175,75 @@ def _mro_has_id(mro: List[TypeInfo], type_id: int):
else:
return False
+ if fullnames is None:
+ return False
+
return mr.fullname in fullnames
def _type_id_for_unbound_type(
type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
-) -> int:
+) -> Optional[int]:
type_id = None
sym = api.lookup_qualified(type_.name, type_)
if sym is not None:
if isinstance(sym.node, TypeAlias):
- type_id = _type_id_for_named_node(sym.node.target.type)
+ target_type = get_proper_type(sym.node.target)
+ if isinstance(target_type, Instance):
+ type_id = _type_id_for_named_node(target_type.type)
elif isinstance(sym.node, TypeInfo):
type_id = _type_id_for_named_node(sym.node)
return type_id
-def _type_id_for_callee(callee: Expression) -> int:
- if isinstance(callee.node, FuncDef):
- return _type_id_for_funcdef(callee.node)
- elif isinstance(callee.node, TypeAlias):
- type_id = _type_id_for_fullname(callee.node.target.type.fullname)
- elif isinstance(callee.node, TypeInfo):
- type_id = _type_id_for_named_node(callee)
- else:
- type_id = None
+def _type_id_for_callee(callee: Expression) -> Optional[int]:
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ return _type_id_for_funcdef(callee.node)
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ type_id = _type_id_for_fullname(target_type.type.fullname)
+ elif isinstance(callee.node, TypeInfo):
+ type_id = _type_id_for_named_node(callee)
+ else:
+ type_id = None
return type_id
-def _type_id_for_funcdef(node: FuncDef) -> int:
- if hasattr(node.type.ret_type, "type"):
- type_id = _type_id_for_fullname(node.type.ret_type.type.fullname)
- else:
- type_id = None
- return type_id
+def _type_id_for_funcdef(node: FuncDef) -> Optional[int]:
+ if node.type and isinstance(node.type, CallableType):
+ ret_type = get_proper_type(node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return _type_id_for_fullname(ret_type.type.fullname)
+
+ return None
-def _type_id_for_named_node(node: Union[RefExpr, SymbolNode]) -> int:
+def _type_id_for_named_node(
+ node: Union[NameExpr, MemberExpr, SymbolNode]
+) -> Optional[int]:
type_id, fullnames = _lookup.get(node.name, (None, None))
- if type_id is None:
+ if type_id is None or fullnames is None:
return None
-
elif node.fullname in fullnames:
return type_id
else:
return None
-def _type_id_for_fullname(fullname: str) -> int:
+def _type_id_for_fullname(fullname: str) -> Optional[int]:
tokens = fullname.split(".")
immediate = tokens[-1]
type_id, fullnames = _lookup.get(immediate, (None, None))
- if type_id is None:
+ if type_id is None or fullnames is None:
return None
-
elif fullname in fullnames:
return type_id
else:
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
index 23585be49..76aac5152 100644
--- a/lib/sqlalchemy/ext/mypy/plugin.py
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -9,9 +9,12 @@
Mypy plugin for SQLAlchemy ORM.
"""
+from typing import Callable
from typing import List
+from typing import Optional
from typing import Tuple
-from typing import Type
+from typing import Type as TypingType
+from typing import Union
from mypy import nodes
from mypy.mro import calculate_mro
@@ -25,20 +28,20 @@ from mypy.nodes import SymbolTable
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import AttributeContext
-from mypy.plugin import Callable
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
-from mypy.plugin import Optional
from mypy.plugin import Plugin
from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import get_proper_type
from mypy.types import Instance
+from mypy.types import Type
from . import decl_class
from . import names
from . import util
-class CustomPlugin(Plugin):
+class SQLAlchemyPlugin(Plugin):
def get_dynamic_class_hook(
self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
@@ -72,7 +75,7 @@ class CustomPlugin(Plugin):
sym = self.lookup_fully_qualified(fullname)
- if sym is not None:
+ if sym is not None and sym.node is not None:
type_id = names._type_id_for_named_node(sym.node)
if type_id is names.MAPPED_DECORATOR:
return _cls_decorator_hook
@@ -109,8 +112,8 @@ class CustomPlugin(Plugin):
]
-def plugin(version: str):
- return CustomPlugin
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+ return SQLAlchemyPlugin
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
@@ -143,14 +146,14 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None:
else:
continue
+ assert isinstance(target.expr, NameExpr)
sym = ctx.api.lookup_qualified(
target.expr.name, target, suppress_errors=True
)
- if sym:
- if sym.node.type and hasattr(sym.node.type, "type"):
- target.fullname = (
- f"{sym.node.type.type.fullname}.{target.name}"
- )
+ if sym and sym.node:
+ sym_type = get_proper_type(sym.type)
+ if isinstance(sym_type, Instance):
+ target.fullname = f"{sym_type.type.fullname}.{target.name}"
else:
# if the registry is in the same file as where the
# decorator is used, it might not have semantic
@@ -170,7 +173,7 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None:
)
-def _add_globals(ctx: ClassDefContext):
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
"""Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
for all class defs
@@ -207,7 +210,15 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
assert isinstance(ctx.reason, nodes.MemberExpr)
expr = ctx.reason.expr
- assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY
+
+ assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
+
+ node_type = get_proper_type(expr.node.type)
+
+ assert (
+ isinstance(node_type, Instance)
+ and names._type_id_for_named_node(node_type.type) is names.REGISTRY
+ )
decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
@@ -237,8 +248,8 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
cls.info = info
_make_declarative_meta(ctx.api, cls)
- cls_arg = util._get_callexpr_kwarg(ctx.call, "cls")
- if cls_arg is not None:
+ cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+ if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
decl_class._scan_declarative_assignments_and_apply_types(
cls_arg.node.defn, ctx.api, is_mixin_scan=True
)
@@ -263,7 +274,7 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
def _make_declarative_meta(
api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
-):
+) -> None:
declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta")
declarative_meta_name.kind = GDEF
@@ -272,6 +283,8 @@ def _make_declarative_meta(
# installed by _add_globals
sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls)
+ assert sym is not None and isinstance(sym.node, nodes.TypeInfo)
+
declarative_meta_typeinfo = sym.node
declarative_meta_name.node = declarative_meta_typeinfo
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index 1c1e56d2c..26bb0ac67 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -1,37 +1,52 @@
+from typing import Any
+from typing import cast
+from typing import Iterable
+from typing import Iterator
+from typing import List
from typing import Optional
-from typing import Sequence
+from typing import overload
from typing import Tuple
-from typing import Type
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import NameExpr
+from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
from mypy.types import UnboundType
from mypy.types import UnionType
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
class DeclClassApplied:
def __init__(
self,
is_mapped: bool,
has_table: bool,
- mapped_attr_names: Sequence[Tuple[str, Type]],
- mapped_mro: Sequence[Type],
+ mapped_attr_names: Iterable[Tuple[str, ProperType]],
+ mapped_mro: Iterable[Instance],
):
self.is_mapped = is_mapped
self.has_table = has_table
- self.mapped_attr_names = mapped_attr_names
- self.mapped_mro = mapped_mro
+ self.mapped_attr_names = list(mapped_attr_names)
+ self.mapped_mro = list(mapped_mro)
def serialize(self) -> JsonDict:
return {
@@ -52,28 +67,34 @@ class DeclClassApplied:
return DeclClassApplied(
is_mapped=data["is_mapped"],
has_table=data["has_table"],
- mapped_attr_names=[
- (name, deserialize_and_fixup_type(type_, api))
- for name, type_ in data["mapped_attr_names"]
- ],
- mapped_mro=[
- deserialize_and_fixup_type(type_, api)
- for type_ in data["mapped_mro"]
- ],
+ mapped_attr_names=cast(
+ List[Tuple[str, ProperType]],
+ [
+ (name, deserialize_and_fixup_type(type_, api))
+ for name, type_ in data["mapped_attr_names"]
+ ],
+ ),
+ mapped_mro=cast(
+ List[Instance],
+ [
+ deserialize_and_fixup_type(type_, api)
+ for type_ in data["mapped_mro"]
+ ],
+ ),
)
-def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
def add_global(
- ctx: ClassDefContext,
+ ctx: Union[ClassDefContext, DynamicClassDefContext],
module: str,
symbol_name: str,
asname: str,
-):
+) -> None:
module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
if asname not in module_globals:
@@ -84,18 +105,50 @@ def add_global(
module_globals[asname] = lookup_sym
-def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]:
+@overload
+def _get_callexpr_kwarg(
+ callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+ ...
+
+
+@overload
+def _get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+ ...
+
+
+def _get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
try:
arg_idx = callexpr.arg_names.index(name)
except ValueError:
return None
- return callexpr.args[arg_idx]
+ kwarg = callexpr.args[arg_idx]
+ if isinstance(
+ kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+ ):
+ return kwarg
+ return None
-def _flatten_typechecking(stmts):
+
+def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
- if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING":
+ if (
+ isinstance(stmt, IfStmt)
+ and isinstance(stmt.expr[0], NameExpr)
+ and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+ ):
for substmt in stmt.body[0].body:
yield substmt
else:
@@ -103,7 +156,7 @@ def _flatten_typechecking(stmts):
def _unbound_to_instance(
- api: SemanticAnalyzerPluginInterface, typ: UnboundType
+ api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
and convert it into an Instance/TypeInfo kind of structure that seems
@@ -130,7 +183,11 @@ def _unbound_to_instance(
node = api.lookup_qualified(typ.name, typ)
- if node is not None and isinstance(node, SymbolTableNode):
+ if (
+ node is not None
+ and isinstance(node, SymbolTableNode)
+ and isinstance(node.node, TypeInfo)
+ ):
bound_type = node.node
return Instance(
@@ -146,12 +203,12 @@ def _unbound_to_instance(
return typ
-def _info_for_cls(cls, api):
+def _info_for_cls(
+ cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> TypeInfo:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
- if sym.node and isinstance(sym.node, TypeInfo):
- info = sym.node
- else:
- info = cls.info
+ assert sym and isinstance(sym.node, TypeInfo)
+ return sym.node
- return info
+ return cls.info