diff options
author | Bryan Forbes <bryan@reigndropsfall.net> | 2021-04-12 16:24:37 -0500 |
---|---|---|
committer | Bryan Forbes <bryan@reigndropsfall.net> | 2021-04-12 16:24:37 -0500 |
commit | e2008b5541cc155aea538317805e62ff1aa9b300 (patch) | |
tree | 04608c82131e8bb3aa2ada56c5e78d4e0a8936d5 /lib/sqlalchemy/ext/mypy | |
parent | de7f14104d5278987fa72d6866fa39569e56077e (diff) | |
download | sqlalchemy-e2008b5541cc155aea538317805e62ff1aa9b300.tar.gz |
Update mypy plugin to conform to strict mode
Change-Id: I09a3df5af2f2d4ee34d8d72c3dedc4f236df8eb1
Diffstat (limited to 'lib/sqlalchemy/ext/mypy')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/apply.py | 36 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/decl_class.py | 122 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/infer.py | 155 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/names.py | 122 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/plugin.py | 47 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 115 |
6 files changed, 382 insertions, 215 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 0f4bb1fd9..366260437 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -24,9 +24,12 @@ from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import add_method_to_class from mypy.types import AnyType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneTyp +from mypy.types import ProperType from mypy.types import TypeOfAny +from mypy.types import UnboundType from mypy.types import UnionType from . import util @@ -37,7 +40,7 @@ def _apply_mypy_mapped_attr( api: SemanticAnalyzerPluginInterface, item: Union[NameExpr, StrExpr], cls_metadata: util.DeclClassApplied, -): +) -> None: if isinstance(item, NameExpr): name = item.name elif isinstance(item, StrExpr): @@ -46,7 +49,11 @@ def _apply_mypy_mapped_attr( return for stmt in cls.defs.body: - if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name: + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == name + ): break else: util.fail(api, "Can't find mapped attribute {}".format(name), cls) @@ -61,7 +68,10 @@ def _apply_mypy_mapped_attr( ) return - left_hand_explicit_type = stmt.type + left_hand_explicit_type = get_proper_type(stmt.type) + assert isinstance( + left_hand_explicit_type, (Instance, UnionType, UnboundType) + ) cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) @@ -74,7 +84,7 @@ def _re_apply_declarative_assignments( cls: ClassDef, api: SemanticAnalyzerPluginInterface, cls_metadata: util.DeclClassApplied, -): +) -> None: """For multiple class passes, re-apply our left-hand side types as mypy seems to reset them in place. @@ -90,7 +100,9 @@ def _re_apply_declarative_assignments( # will change). if ( isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) and stmt.lvalues[0].name in mapped_attr_lookup + and isinstance(stmt.lvalues[0].node, Var) ): typ = mapped_attr_lookup[stmt.lvalues[0].name] left_node = stmt.lvalues[0].node @@ -102,8 +114,8 @@ def _apply_type_to_mapped_statement( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, lvalue: NameExpr, - left_hand_explicit_type: Optional[Union[Instance, UnionType]], - python_type_for_type: Union[Instance, UnionType], + left_hand_explicit_type: Optional[ProperType], + python_type_for_type: Optional[ProperType], ) -> None: """Apply the Mapped[<type>] annotation and right hand object to a declarative assignment statement. @@ -124,6 +136,7 @@ def _apply_type_to_mapped_statement( """ left_node = lvalue.node + assert isinstance(left_node, Var) if left_hand_explicit_type is not None: left_node.type = api.named_type( @@ -131,7 +144,10 @@ def _apply_type_to_mapped_statement( ) else: lvalue.is_inferred_def = False - left_node.type = api.named_type("__sa_Mapped", [python_type_for_type]) + left_node.type = api.named_type( + "__sa_Mapped", + [] if python_type_for_type is None else [python_type_for_type], + ) # so to have it skip the right side totally, we can do this: # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) @@ -146,7 +162,7 @@ def _apply_type_to_mapped_statement( # the original right-hand side is maintained so it gets type checked # internally column_descriptor = nodes.NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.Mapped" + column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") orig_call_expr = stmt.rvalue stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"]) @@ -199,11 +215,11 @@ def _apply_placeholder_attr_to_class( cls: ClassDef, qualified_name: str, attrname: str, -): +) -> None: sym = api.lookup_fully_qualified_or_none(qualified_name) if sym: assert isinstance(sym.node, TypeInfo) - type_: Union[Instance, AnyType] = Instance(sym.node, []) + type_: ProperType = Instance(sym.node, []) else: type_ = AnyType(TypeOfAny.special_form) var = Var(attrname) diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 40f1f0c0f..8fac36342 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -6,7 +6,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from typing import Optional -from typing import Type +from typing import Union from mypy import nodes from mypy.nodes import AssignmentStmt @@ -14,18 +14,24 @@ from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Decorator from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import PlaceholderNode from mypy.nodes import RefExpr from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode from mypy.nodes import SymbolTableNode from mypy.nodes import TempNode from mypy.nodes import TypeInfo from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type from mypy.types import TypeOfAny from mypy.types import UnboundType from mypy.types import UnionType @@ -37,7 +43,9 @@ from . import util def _scan_declarative_assignments_and_apply_types( - cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, ) -> Optional[util.DeclClassApplied]: info = util._info_for_cls(cls, api) @@ -94,16 +102,17 @@ def _scan_symbol_table_entry( name: str, value: SymbolTableNode, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from a SymbolTableNode that's in the type.names dictionary. """ - if not isinstance(value.type, Instance): + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): return left_hand_explicit_type = None - type_id = names._type_id_for_named_node(value.type.type) + type_id = names._type_id_for_named_node(value_type.type) # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) err = False @@ -118,22 +127,24 @@ def _scan_symbol_table_entry( names.SYNONYM_PROPERTY, names.COLUMN_PROPERTY, }: - if value.type.args: - left_hand_explicit_type = value.type.args[0] + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) else: err = True elif type_id is names.COLUMN: - if not value.type.args: + if not value_type.args: err = True else: - typeengine_arg = value.type.args[0] + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) if isinstance(typeengine_arg, Instance): typeengine_arg = typeengine_arg.type if isinstance(typeengine_arg, (UnboundType, TypeInfo)): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) - if sym is not None: - if names._mro_has_id(sym.node.mro, names.TYPEENGINE): + if sym is not None and isinstance(sym.node, TypeInfo): + if names._has_base_type_id(sym.node, names.TYPEENGINE): left_hand_explicit_type = UnionType( [ @@ -148,7 +159,7 @@ def _scan_symbol_table_entry( api, "Column type should be a TypeEngine " "subclass not '{}'".format(sym.node.fullname), - value.type, + value_type, ) if err: @@ -158,7 +169,7 @@ def _scan_symbol_table_entry( "one of: Mapped[<python type>], relationship[<target class>], " "Column[<TypeEngine>], MapperProperty[<python type>]" ) - util.fail(api, msg.format(name, cls.name)) + util.fail(api, msg.format(name, cls.name), cls) left_hand_explicit_type = AnyType(TypeOfAny.special_form) @@ -171,7 +182,7 @@ def _scan_declarative_decorator_stmt( api: SemanticAnalyzerPluginInterface, stmt: Decorator, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from a @declared_attr in a declarative class. @@ -195,16 +206,19 @@ def _scan_declarative_decorator_stmt( """ for dec in stmt.decorators: - if names._type_id_for_named_node(dec) is names.DECLARED_ATTR: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names._type_id_for_named_node(dec) is names.DECLARED_ATTR + ): break else: return dec_index = cls.defs.body.index(stmt) - left_hand_explicit_type = None + left_hand_explicit_type: Optional[ProperType] = None - if stmt.func.type is not None: + if isinstance(stmt.func.type, CallableType): func_type = stmt.func.type.ret_type if isinstance(func_type, UnboundType): type_id = names._type_id_for_unbound_type(func_type, cls, api) @@ -225,30 +239,28 @@ def _scan_declarative_decorator_stmt( } and func_type.args ): - left_hand_explicit_type = func_type.args[0] + left_hand_explicit_type = get_proper_type(func_type.args[0]) elif type_id is names.COLUMN and func_type.args: typeengine_arg = func_type.args[0] if isinstance(typeengine_arg, UnboundType): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) - if sym is not None and names._mro_has_id( - sym.node.mro, names.TYPEENGINE - ): - - left_hand_explicit_type = UnionType( - [ - infer._extract_python_type_from_typeengine( - api, sym.node, [] - ), - NoneType(), - ] - ) - else: - util.fail( - api, - "Column type should be a TypeEngine " - "subclass not '{}'".format(sym.node.fullname), - func_type, - ) + if sym is not None and isinstance(sym.node, TypeInfo): + if names._has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer._extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) if left_hand_explicit_type is None: # no type on the decorated function. our option here is to @@ -274,8 +286,8 @@ def _scan_declarative_decorator_stmt( # of converting it to the regular Instance/TypeInfo/UnionType structures # we see everywhere else. if isinstance(left_hand_explicit_type, UnboundType): - left_hand_explicit_type = util._unbound_to_instance( - api, left_hand_explicit_type + left_hand_explicit_type = get_proper_type( + util._unbound_to_instance(api, left_hand_explicit_type) ) left_node.node.type = api.named_type( @@ -315,7 +327,7 @@ def _scan_declarative_assignment_stmt( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from an assignment statement in a declarative class. @@ -339,7 +351,7 @@ def _scan_declarative_assignment_stmt( assert isinstance(node, Var) if node.name == "__abstract__": - if stmt.rvalue.fullname == "builtins.True": + if api.parse_bool(stmt.rvalue) is True: cls_metadata.is_mapped = False return elif node.name == "__tablename__": @@ -354,7 +366,8 @@ def _scan_declarative_assignment_stmt( if isinstance(item, (NameExpr, StrExpr)): apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata) - left_hand_mapped_type: Type = None + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None if node.is_inferred or node.type is None: if isinstance(stmt.type, UnboundType): @@ -370,32 +383,33 @@ def _scan_declarative_assignment_stmt( mapped_sym = api.lookup_qualified("Mapped", cls) if ( mapped_sym is not None + and mapped_sym.node is not None and names._type_id_for_named_node(mapped_sym.node) is names.MAPPED ): - left_hand_explicit_type = stmt.type.args[0] + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) left_hand_mapped_type = stmt.type # TODO: do we need to convert from unbound for this case? # left_hand_explicit_type = util._unbound_to_instance( # api, left_hand_explicit_type # ) - - else: - left_hand_explicit_type = None else: + node_type = get_proper_type(node.type) if ( - isinstance(node.type, Instance) - and names._type_id_for_named_node(node.type.type) is names.MAPPED + isinstance(node_type, Instance) + and names._type_id_for_named_node(node_type.type) is names.MAPPED ): # print(node.type) # sqlalchemy.orm.attributes.Mapped[<python type>] - left_hand_explicit_type = node.type.args[0] - left_hand_mapped_type = node.type + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type else: # print(node.type) # <python type> - left_hand_explicit_type = node.type + left_hand_explicit_type = node_type left_hand_mapped_type = None if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: @@ -440,10 +454,10 @@ def _scan_declarative_assignment_stmt( else: return - cls_metadata.mapped_attr_names.append((node.name, python_type_for_type)) - assert python_type_for_type is not None + cls_metadata.mapped_attr_names.append((node.name, python_type_for_type)) + apply._apply_type_to_mapped_statement( api, stmt, @@ -485,6 +499,6 @@ def _scan_for_mapped_bases( ) ) - if base_decl_class_applied not in (None, False): + if base_decl_class_applied is not None: cls_metadata.mapped_mro.append(base) baseclasses.extend(base.type.bases) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index f1bda7865..7915c3ae2 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -6,23 +6,26 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from typing import Optional -from typing import Union +from typing import Sequence -from mypy import nodes -from mypy import types from mypy.maptype import map_instance_to_supertype from mypy.messages import format_type from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import RefExpr from mypy.nodes import StrExpr from mypy.nodes import TypeInfo from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.subtypes import is_subtype from mypy.types import AnyType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType from mypy.types import TypeOfAny from mypy.types import UnionType @@ -34,8 +37,8 @@ def _infer_type_from_relationship( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a relationship. E.g.:: @@ -62,7 +65,7 @@ def _infer_type_from_relationship( assert isinstance(stmt.rvalue, CallExpr) target_cls_arg = stmt.rvalue.args[0] - python_type_for_type = None + python_type_for_type: Optional[ProperType] = None if isinstance(target_cls_arg, NameExpr) and isinstance( target_cls_arg.node, TypeInfo @@ -86,7 +89,7 @@ def _infer_type_from_relationship( # isinstance(target_cls_arg, StrExpr) uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist") - collection_cls_arg = util._get_callexpr_kwarg( + collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg( stmt.rvalue, "collection_class" ) type_is_a_collection = False @@ -98,7 +101,7 @@ def _infer_type_from_relationship( if ( uselist_arg is not None - and uselist_arg.fullname == "builtins.True" + and api.parse_bool(uselist_arg) is True and collection_cls_arg is None ): type_is_a_collection = True @@ -107,7 +110,7 @@ def _infer_type_from_relationship( "__builtins__.list", [python_type_for_type] ) elif ( - uselist_arg is None or uselist_arg.fullname == "builtins.True" + uselist_arg is None or api.parse_bool(uselist_arg) is True ) and collection_cls_arg is not None: type_is_a_collection = True if isinstance(collection_cls_arg, CallExpr): @@ -130,7 +133,7 @@ def _infer_type_from_relationship( stmt.rvalue, ) python_type_for_type = None - elif uselist_arg is not None and uselist_arg.fullname == "builtins.False": + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: if collection_cls_arg is not None: util.fail( api, @@ -159,13 +162,19 @@ def _infer_type_from_relationship( api, node, left_hand_explicit_type ) elif left_hand_explicit_type is not None: - return _infer_type_from_left_and_inferred_right( - api, - node, - left_hand_explicit_type, - python_type_for_type, - type_is_a_collection=type_is_a_collection, - ) + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) else: return python_type_for_type @@ -174,8 +183,8 @@ def _infer_type_from_decl_composite_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a CompositeProperty.""" assert isinstance(stmt.rvalue, CallExpr) @@ -206,8 +215,8 @@ def _infer_type_from_decl_column_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a ColumnProperty. This includes mappings against ``column_property()`` as well as the @@ -219,28 +228,26 @@ def _infer_type_from_decl_column_property( if isinstance(first_prop_arg, CallExpr): type_id = names._type_id_for_callee(first_prop_arg.callee) - else: - type_id = None - # look for column_property() / deferred() etc with Column as first - # argument - if type_id is names.COLUMN: - return _infer_type_from_decl_column( - api, stmt, node, left_hand_explicit_type, first_prop_arg - ) - else: - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type, first_prop_arg + ) + + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) def _infer_type_from_decl_column( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], + left_hand_explicit_type: Optional[ProperType], right_hand_expression: CallExpr, -) -> Union[Instance, UnionType, None]: +) -> Optional[ProperType]: """Infer the type of mapping from a Column. E.g.:: @@ -277,12 +284,13 @@ def _infer_type_from_decl_column( callee = None for column_arg in right_hand_expression.args[0:2]: - if isinstance(column_arg, nodes.CallExpr): - # x = Column(String(50)) - callee = column_arg.callee - type_args = column_arg.args - break - elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)): + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): if isinstance(column_arg.node, TypeInfo): # x = Column(String) callee = column_arg @@ -314,10 +322,7 @@ def _infer_type_from_decl_column( ) else: - python_type_for_type = UnionType( - [python_type_for_type, NoneType()] - ) - return python_type_for_type + return UnionType([python_type_for_type, NoneType()]) else: # it's not TypeEngine, it's typically implicitly typed # like ForeignKey. we can't infer from the right side. @@ -329,10 +334,11 @@ def _infer_type_from_decl_column( def _infer_type_from_left_and_inferred_right( api: SemanticAnalyzerPluginInterface, node: Var, - left_hand_explicit_type: Optional[types.Type], - python_type_for_type: Union[Instance, UnionType], - type_is_a_collection: bool = False, -) -> Optional[Union[Instance, UnionType]]: + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: """Validate type when a left hand annotation is present and we also could infer the right hand side:: @@ -340,12 +346,10 @@ def _infer_type_from_left_and_inferred_right( """ - orig_left_hand_type = left_hand_explicit_type - orig_python_type_for_type = python_type_for_type - - if type_is_a_collection and left_hand_explicit_type.args: - left_hand_explicit_type = left_hand_explicit_type.args[0] - python_type_for_type = python_type_for_type.args[0] + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type if not is_subtype(left_hand_explicit_type, python_type_for_type): effective_type = api.named_type( @@ -369,11 +373,40 @@ def _infer_type_from_left_and_inferred_right( return orig_left_hand_type +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + def _infer_type_from_left_hand_type_only( api: SemanticAnalyzerPluginInterface, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Optional[Union[Instance, UnionType]]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Determine the type based on explicit annotation only. if no annotation were present, note that we need one there to know @@ -397,8 +430,10 @@ def _infer_type_from_left_hand_type_only( def _extract_python_type_from_typeengine( - api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args -) -> Instance: + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: first_arg = type_args[0] if isinstance(first_arg, NameExpr) and isinstance( @@ -426,4 +461,4 @@ def _extract_python_type_from_typeengine( Instance(node, []), type_engine_sym.node, ) - return type_engine.args[-1] + return get_proper_type(type_engine.args[-1]) diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index 174a8f422..6ee600cd7 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -5,40 +5,48 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from typing import Dict from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union from mypy.nodes import ClassDef from mypy.nodes import Expression from mypy.nodes import FuncDef -from mypy.nodes import RefExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr from mypy.nodes import SymbolNode from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo -from mypy.nodes import Union from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance from mypy.types import UnboundType from ... import util -COLUMN = util.symbol("COLUMN") -RELATIONSHIP = util.symbol("RELATIONSHIP") -REGISTRY = util.symbol("REGISTRY") -COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY") -TYPEENGINE = util.symbol("TYPEENGNE") -MAPPED = util.symbol("MAPPED") -DECLARATIVE_BASE = util.symbol("DECLARATIVE_BASE") -DECLARATIVE_META = util.symbol("DECLARATIVE_META") -MAPPED_DECORATOR = util.symbol("MAPPED_DECORATOR") -COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY") -SYNONYM_PROPERTY = util.symbol("SYNONYM_PROPERTY") -COMPOSITE_PROPERTY = util.symbol("COMPOSITE_PROPERTY") -DECLARED_ATTR = util.symbol("DECLARED_ATTR") -MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY") -AS_DECLARATIVE = util.symbol("AS_DECLARATIVE") -AS_DECLARATIVE_BASE = util.symbol("AS_DECLARATIVE_BASE") -DECLARATIVE_MIXIN = util.symbol("DECLARATIVE_MIXIN") - -_lookup = { +COLUMN: int = util.symbol("COLUMN") # type: ignore +RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore +REGISTRY: int = util.symbol("REGISTRY") # type: ignore +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore +TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore +MAPPED: int = util.symbol("MAPPED") # type: ignore +DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore +DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore +MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore +SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore +COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore +DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore +MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore +AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore +AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore +DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore + +_lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( COLUMN, { @@ -145,7 +153,21 @@ _lookup = { } -def _mro_has_id(mro: List[TypeInfo], type_id: int): +def _has_base_type_id(info: TypeInfo, type_id: int) -> bool: + for mr in info.mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: for mr in mro: check_type_id, fullnames = _lookup.get(mr.name, (None, None)) if check_type_id == type_id: @@ -153,65 +175,75 @@ def _mro_has_id(mro: List[TypeInfo], type_id: int): else: return False + if fullnames is None: + return False + return mr.fullname in fullnames def _type_id_for_unbound_type( type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface -) -> int: +) -> Optional[int]: type_id = None sym = api.lookup_qualified(type_.name, type_) if sym is not None: if isinstance(sym.node, TypeAlias): - type_id = _type_id_for_named_node(sym.node.target.type) + target_type = get_proper_type(sym.node.target) + if isinstance(target_type, Instance): + type_id = _type_id_for_named_node(target_type.type) elif isinstance(sym.node, TypeInfo): type_id = _type_id_for_named_node(sym.node) return type_id -def _type_id_for_callee(callee: Expression) -> int: - if isinstance(callee.node, FuncDef): - return _type_id_for_funcdef(callee.node) - elif isinstance(callee.node, TypeAlias): - type_id = _type_id_for_fullname(callee.node.target.type.fullname) - elif isinstance(callee.node, TypeInfo): - type_id = _type_id_for_named_node(callee) - else: - type_id = None +def _type_id_for_callee(callee: Expression) -> Optional[int]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + return _type_id_for_funcdef(callee.node) + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + type_id = _type_id_for_fullname(target_type.type.fullname) + elif isinstance(callee.node, TypeInfo): + type_id = _type_id_for_named_node(callee) + else: + type_id = None return type_id -def _type_id_for_funcdef(node: FuncDef) -> int: - if hasattr(node.type.ret_type, "type"): - type_id = _type_id_for_fullname(node.type.ret_type.type.fullname) - else: - type_id = None - return type_id +def _type_id_for_funcdef(node: FuncDef) -> Optional[int]: + if node.type and isinstance(node.type, CallableType): + ret_type = get_proper_type(node.type.ret_type) + + if isinstance(ret_type, Instance): + return _type_id_for_fullname(ret_type.type.fullname) + + return None -def _type_id_for_named_node(node: Union[RefExpr, SymbolNode]) -> int: +def _type_id_for_named_node( + node: Union[NameExpr, MemberExpr, SymbolNode] +) -> Optional[int]: type_id, fullnames = _lookup.get(node.name, (None, None)) - if type_id is None: + if type_id is None or fullnames is None: return None - elif node.fullname in fullnames: return type_id else: return None -def _type_id_for_fullname(fullname: str) -> int: +def _type_id_for_fullname(fullname: str) -> Optional[int]: tokens = fullname.split(".") immediate = tokens[-1] type_id, fullnames = _lookup.get(immediate, (None, None)) - if type_id is None: + if type_id is None or fullnames is None: return None - elif fullname in fullnames: return type_id else: diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 23585be49..76aac5152 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -9,9 +9,12 @@ Mypy plugin for SQLAlchemy ORM. """ +from typing import Callable from typing import List +from typing import Optional from typing import Tuple -from typing import Type +from typing import Type as TypingType +from typing import Union from mypy import nodes from mypy.mro import calculate_mro @@ -25,20 +28,20 @@ from mypy.nodes import SymbolTable from mypy.nodes import SymbolTableNode from mypy.nodes import TypeInfo from mypy.plugin import AttributeContext -from mypy.plugin import Callable from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext -from mypy.plugin import Optional from mypy.plugin import Plugin from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import get_proper_type from mypy.types import Instance +from mypy.types import Type from . import decl_class from . import names from . import util -class CustomPlugin(Plugin): +class SQLAlchemyPlugin(Plugin): def get_dynamic_class_hook( self, fullname: str ) -> Optional[Callable[[DynamicClassDefContext], None]]: @@ -72,7 +75,7 @@ class CustomPlugin(Plugin): sym = self.lookup_fully_qualified(fullname) - if sym is not None: + if sym is not None and sym.node is not None: type_id = names._type_id_for_named_node(sym.node) if type_id is names.MAPPED_DECORATOR: return _cls_decorator_hook @@ -109,8 +112,8 @@ class CustomPlugin(Plugin): ] -def plugin(version: str): - return CustomPlugin +def plugin(version: str) -> TypingType[SQLAlchemyPlugin]: + return SQLAlchemyPlugin def _queryable_getattr_hook(ctx: AttributeContext) -> Type: @@ -143,14 +146,14 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None: else: continue + assert isinstance(target.expr, NameExpr) sym = ctx.api.lookup_qualified( target.expr.name, target, suppress_errors=True ) - if sym: - if sym.node.type and hasattr(sym.node.type, "type"): - target.fullname = ( - f"{sym.node.type.type.fullname}.{target.name}" - ) + if sym and sym.node: + sym_type = get_proper_type(sym.type) + if isinstance(sym_type, Instance): + target.fullname = f"{sym_type.type.fullname}.{target.name}" else: # if the registry is in the same file as where the # decorator is used, it might not have semantic @@ -170,7 +173,7 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None: ) -def _add_globals(ctx: ClassDefContext): +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space for all class defs @@ -207,7 +210,15 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None: _add_globals(ctx) assert isinstance(ctx.reason, nodes.MemberExpr) expr = ctx.reason.expr - assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY + + assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var) + + node_type = get_proper_type(expr.node.type) + + assert ( + isinstance(node_type, Instance) + and names._type_id_for_named_node(node_type.type) is names.REGISTRY + ) decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) @@ -237,8 +248,8 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: cls.info = info _make_declarative_meta(ctx.api, cls) - cls_arg = util._get_callexpr_kwarg(ctx.call, "cls") - if cls_arg is not None: + cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) + if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): decl_class._scan_declarative_assignments_and_apply_types( cls_arg.node.defn, ctx.api, is_mixin_scan=True ) @@ -263,7 +274,7 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: def _make_declarative_meta( api: SemanticAnalyzerPluginInterface, target_cls: ClassDef -): +) -> None: declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta") declarative_meta_name.kind = GDEF @@ -272,6 +283,8 @@ def _make_declarative_meta( # installed by _add_globals sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls) + assert sym is not None and isinstance(sym.node, nodes.TypeInfo) + declarative_meta_typeinfo = sym.node declarative_meta_name.node = declarative_meta_typeinfo diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 1c1e56d2c..26bb0ac67 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,37 +1,52 @@ +from typing import Any +from typing import cast +from typing import Iterable +from typing import Iterator +from typing import List from typing import Optional -from typing import Sequence +from typing import overload from typing import Tuple -from typing import Type +from typing import Type as TypingType +from typing import TypeVar +from typing import Union from mypy.nodes import CallExpr +from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import NameExpr +from mypy.nodes import Statement from mypy.nodes import SymbolTableNode from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type from mypy.types import UnboundType from mypy.types import UnionType +_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + class DeclClassApplied: def __init__( self, is_mapped: bool, has_table: bool, - mapped_attr_names: Sequence[Tuple[str, Type]], - mapped_mro: Sequence[Type], + mapped_attr_names: Iterable[Tuple[str, ProperType]], + mapped_mro: Iterable[Instance], ): self.is_mapped = is_mapped self.has_table = has_table - self.mapped_attr_names = mapped_attr_names - self.mapped_mro = mapped_mro + self.mapped_attr_names = list(mapped_attr_names) + self.mapped_mro = list(mapped_mro) def serialize(self) -> JsonDict: return { @@ -52,28 +67,34 @@ class DeclClassApplied: return DeclClassApplied( is_mapped=data["is_mapped"], has_table=data["has_table"], - mapped_attr_names=[ - (name, deserialize_and_fixup_type(type_, api)) - for name, type_ in data["mapped_attr_names"] - ], - mapped_mro=[ - deserialize_and_fixup_type(type_, api) - for type_ in data["mapped_mro"] - ], + mapped_attr_names=cast( + List[Tuple[str, ProperType]], + [ + (name, deserialize_and_fixup_type(type_, api)) + for name, type_ in data["mapped_attr_names"] + ], + ), + mapped_mro=cast( + List[Instance], + [ + deserialize_and_fixup_type(type_, api) + for type_ in data["mapped_mro"] + ], + ), ) -def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context): +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: msg = "[SQLAlchemy Mypy plugin] %s" % msg return api.fail(msg, ctx) def add_global( - ctx: ClassDefContext, + ctx: Union[ClassDefContext, DynamicClassDefContext], module: str, symbol_name: str, asname: str, -): +) -> None: module_globals = ctx.api.modules[ctx.api.cur_mod_id].names if asname not in module_globals: @@ -84,18 +105,50 @@ def add_global( module_globals[asname] = lookup_sym -def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]: +@overload +def _get_callexpr_kwarg( + callexpr: CallExpr, name: str, *, expr_types: None = ... +) -> Optional[Union[CallExpr, NameExpr]]: + ... + + +@overload +def _get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Tuple[TypingType[_TArgType], ...] +) -> Optional[_TArgType]: + ... + + +def _get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Optional[Tuple[TypingType[Any], ...]] = None +) -> Optional[Any]: try: arg_idx = callexpr.arg_names.index(name) except ValueError: return None - return callexpr.args[arg_idx] + kwarg = callexpr.args[arg_idx] + if isinstance( + kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr) + ): + return kwarg + return None -def _flatten_typechecking(stmts): + +def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: for stmt in stmts: - if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING": + if ( + isinstance(stmt, IfStmt) + and isinstance(stmt.expr[0], NameExpr) + and stmt.expr[0].fullname == "typing.TYPE_CHECKING" + ): for substmt in stmt.body[0].body: yield substmt else: @@ -103,7 +156,7 @@ def _flatten_typechecking(stmts): def _unbound_to_instance( - api: SemanticAnalyzerPluginInterface, typ: UnboundType + api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: """Take the UnboundType that we seem to get as the ret_type from a FuncDef and convert it into an Instance/TypeInfo kind of structure that seems @@ -130,7 +183,11 @@ def _unbound_to_instance( node = api.lookup_qualified(typ.name, typ) - if node is not None and isinstance(node, SymbolTableNode): + if ( + node is not None + and isinstance(node, SymbolTableNode) + and isinstance(node.node, TypeInfo) + ): bound_type = node.node return Instance( @@ -146,12 +203,12 @@ def _unbound_to_instance( return typ -def _info_for_cls(cls, api): +def _info_for_cls( + cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> TypeInfo: if cls.info is CLASSDEF_NO_INFO: sym = api.lookup_qualified(cls.name, cls) - if sym.node and isinstance(sym.node, TypeInfo): - info = sym.node - else: - info = cls.info + assert sym and isinstance(sym.node, TypeInfo) + return sym.node - return info + return cls.info |