diff options
author | Bryan Forbes <bryan@reigndropsfall.net> | 2021-04-12 16:24:37 -0500 |
---|---|---|
committer | Bryan Forbes <bryan@reigndropsfall.net> | 2021-04-12 16:24:37 -0500 |
commit | e2008b5541cc155aea538317805e62ff1aa9b300 (patch) | |
tree | 04608c82131e8bb3aa2ada56c5e78d4e0a8936d5 /lib/sqlalchemy/ext/mypy/util.py | |
parent | de7f14104d5278987fa72d6866fa39569e56077e (diff) | |
download | sqlalchemy-e2008b5541cc155aea538317805e62ff1aa9b300.tar.gz |
Update mypy plugin to conform to strict mode
Change-Id: I09a3df5af2f2d4ee34d8d72c3dedc4f236df8eb1
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 115 |
1 files changed, 86 insertions, 29 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 1c1e56d2c..26bb0ac67 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,37 +1,52 @@ +from typing import Any +from typing import cast +from typing import Iterable +from typing import Iterator +from typing import List from typing import Optional -from typing import Sequence +from typing import overload from typing import Tuple -from typing import Type +from typing import Type as TypingType +from typing import TypeVar +from typing import Union from mypy.nodes import CallExpr +from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import NameExpr +from mypy.nodes import Statement from mypy.nodes import SymbolTableNode from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type from mypy.types import UnboundType from mypy.types import UnionType +_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + class DeclClassApplied: def __init__( self, is_mapped: bool, has_table: bool, - mapped_attr_names: Sequence[Tuple[str, Type]], - mapped_mro: Sequence[Type], + mapped_attr_names: Iterable[Tuple[str, ProperType]], + mapped_mro: Iterable[Instance], ): self.is_mapped = is_mapped self.has_table = has_table - self.mapped_attr_names = mapped_attr_names - self.mapped_mro = mapped_mro + self.mapped_attr_names = list(mapped_attr_names) + self.mapped_mro = list(mapped_mro) def serialize(self) -> JsonDict: return { @@ -52,28 +67,34 @@ class DeclClassApplied: return DeclClassApplied( is_mapped=data["is_mapped"], has_table=data["has_table"], - mapped_attr_names=[ - (name, deserialize_and_fixup_type(type_, api)) - for name, type_ in data["mapped_attr_names"] - ], - mapped_mro=[ - deserialize_and_fixup_type(type_, api) - for type_ in data["mapped_mro"] - ], + mapped_attr_names=cast( + List[Tuple[str, ProperType]], + [ + (name, deserialize_and_fixup_type(type_, api)) + for name, type_ in data["mapped_attr_names"] + ], + ), + mapped_mro=cast( + List[Instance], + [ + deserialize_and_fixup_type(type_, api) + for type_ in data["mapped_mro"] + ], + ), ) -def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context): +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: msg = "[SQLAlchemy Mypy plugin] %s" % msg return api.fail(msg, ctx) def add_global( - ctx: ClassDefContext, + ctx: Union[ClassDefContext, DynamicClassDefContext], module: str, symbol_name: str, asname: str, -): +) -> None: module_globals = ctx.api.modules[ctx.api.cur_mod_id].names if asname not in module_globals: @@ -84,18 +105,50 @@ def add_global( module_globals[asname] = lookup_sym -def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]: +@overload +def _get_callexpr_kwarg( + callexpr: CallExpr, name: str, *, expr_types: None = ... +) -> Optional[Union[CallExpr, NameExpr]]: + ... + + +@overload +def _get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Tuple[TypingType[_TArgType], ...] +) -> Optional[_TArgType]: + ... + + +def _get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Optional[Tuple[TypingType[Any], ...]] = None +) -> Optional[Any]: try: arg_idx = callexpr.arg_names.index(name) except ValueError: return None - return callexpr.args[arg_idx] + kwarg = callexpr.args[arg_idx] + if isinstance( + kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr) + ): + return kwarg + return None -def _flatten_typechecking(stmts): + +def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: for stmt in stmts: - if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING": + if ( + isinstance(stmt, IfStmt) + and isinstance(stmt.expr[0], NameExpr) + and stmt.expr[0].fullname == "typing.TYPE_CHECKING" + ): for substmt in stmt.body[0].body: yield substmt else: @@ -103,7 +156,7 @@ def _flatten_typechecking(stmts): def _unbound_to_instance( - api: SemanticAnalyzerPluginInterface, typ: UnboundType + api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: """Take the UnboundType that we seem to get as the ret_type from a FuncDef and convert it into an Instance/TypeInfo kind of structure that seems @@ -130,7 +183,11 @@ def _unbound_to_instance( node = api.lookup_qualified(typ.name, typ) - if node is not None and isinstance(node, SymbolTableNode): + if ( + node is not None + and isinstance(node, SymbolTableNode) + and isinstance(node.node, TypeInfo) + ): bound_type = node.node return Instance( @@ -146,12 +203,12 @@ def _unbound_to_instance( return typ -def _info_for_cls(cls, api): +def _info_for_cls( + cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> TypeInfo: if cls.info is CLASSDEF_NO_INFO: sym = api.lookup_qualified(cls.name, cls) - if sym.node and isinstance(sym.node, TypeInfo): - info = sym.node - else: - info = cls.info + assert sym and isinstance(sym.node, TypeInfo) + return sym.node - return info + return cls.info |