summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py115
1 files changed, 86 insertions, 29 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index 1c1e56d2c..26bb0ac67 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -1,37 +1,52 @@
+from typing import Any
+from typing import cast
+from typing import Iterable
+from typing import Iterator
+from typing import List
from typing import Optional
-from typing import Sequence
+from typing import overload
from typing import Tuple
-from typing import Type
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import NameExpr
+from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
from mypy.types import UnboundType
from mypy.types import UnionType
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
class DeclClassApplied:
def __init__(
self,
is_mapped: bool,
has_table: bool,
- mapped_attr_names: Sequence[Tuple[str, Type]],
- mapped_mro: Sequence[Type],
+ mapped_attr_names: Iterable[Tuple[str, ProperType]],
+ mapped_mro: Iterable[Instance],
):
self.is_mapped = is_mapped
self.has_table = has_table
- self.mapped_attr_names = mapped_attr_names
- self.mapped_mro = mapped_mro
+ self.mapped_attr_names = list(mapped_attr_names)
+ self.mapped_mro = list(mapped_mro)
def serialize(self) -> JsonDict:
return {
@@ -52,28 +67,34 @@ class DeclClassApplied:
return DeclClassApplied(
is_mapped=data["is_mapped"],
has_table=data["has_table"],
- mapped_attr_names=[
- (name, deserialize_and_fixup_type(type_, api))
- for name, type_ in data["mapped_attr_names"]
- ],
- mapped_mro=[
- deserialize_and_fixup_type(type_, api)
- for type_ in data["mapped_mro"]
- ],
+ mapped_attr_names=cast(
+ List[Tuple[str, ProperType]],
+ [
+ (name, deserialize_and_fixup_type(type_, api))
+ for name, type_ in data["mapped_attr_names"]
+ ],
+ ),
+ mapped_mro=cast(
+ List[Instance],
+ [
+ deserialize_and_fixup_type(type_, api)
+ for type_ in data["mapped_mro"]
+ ],
+ ),
)
-def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
def add_global(
- ctx: ClassDefContext,
+ ctx: Union[ClassDefContext, DynamicClassDefContext],
module: str,
symbol_name: str,
asname: str,
-):
+) -> None:
module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
if asname not in module_globals:
@@ -84,18 +105,50 @@ def add_global(
module_globals[asname] = lookup_sym
-def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]:
+@overload
+def _get_callexpr_kwarg(
+ callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+ ...
+
+
+@overload
+def _get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+ ...
+
+
+def _get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
try:
arg_idx = callexpr.arg_names.index(name)
except ValueError:
return None
- return callexpr.args[arg_idx]
+ kwarg = callexpr.args[arg_idx]
+ if isinstance(
+ kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+ ):
+ return kwarg
+ return None
-def _flatten_typechecking(stmts):
+
+def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
- if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING":
+ if (
+ isinstance(stmt, IfStmt)
+ and isinstance(stmt.expr[0], NameExpr)
+ and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+ ):
for substmt in stmt.body[0].body:
yield substmt
else:
@@ -103,7 +156,7 @@ def _flatten_typechecking(stmts):
def _unbound_to_instance(
- api: SemanticAnalyzerPluginInterface, typ: UnboundType
+ api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
and convert it into an Instance/TypeInfo kind of structure that seems
@@ -130,7 +183,11 @@ def _unbound_to_instance(
node = api.lookup_qualified(typ.name, typ)
- if node is not None and isinstance(node, SymbolTableNode):
+ if (
+ node is not None
+ and isinstance(node, SymbolTableNode)
+ and isinstance(node.node, TypeInfo)
+ ):
bound_type = node.node
return Instance(
@@ -146,12 +203,12 @@ def _unbound_to_instance(
return typ
-def _info_for_cls(cls, api):
+def _info_for_cls(
+ cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> TypeInfo:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
- if sym.node and isinstance(sym.node, TypeInfo):
- info = sym.node
- else:
- info = cls.info
+ assert sym and isinstance(sym.node, TypeInfo)
+ return sym.node
- return info
+ return cls.info