diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 62 |
1 files changed, 61 insertions, 1 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 7079f3cd7..becce3ebe 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,18 +1,67 @@ from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from mypy.nodes import CallExpr +from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import IfStmt +from mypy.nodes import JsonDict from mypy.nodes import NameExpr from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type from mypy.types import Instance from mypy.types import NoneType -from mypy.types import Type from mypy.types import UnboundType from mypy.types import UnionType +class DeclClassApplied: + def __init__( + self, + is_mapped: bool, + has_table: bool, + mapped_attr_names: Sequence[Tuple[str, Type]], + mapped_mro: Sequence[Type], + ): + self.is_mapped = is_mapped + self.has_table = has_table + self.mapped_attr_names = mapped_attr_names + self.mapped_mro = mapped_mro + + def serialize(self) -> JsonDict: + return { + "is_mapped": self.is_mapped, + "has_table": self.has_table, + "mapped_attr_names": [ + (name, type_.serialize()) + for name, type_ in self.mapped_attr_names + ], + "mapped_mro": [type_.serialize() for type_ in self.mapped_mro], + } + + @classmethod + def deserialize( + cls, data: JsonDict, api: SemanticAnalyzerPluginInterface + ) -> "DeclClassApplied": + + return DeclClassApplied( + is_mapped=data["is_mapped"], + has_table=data["has_table"], + mapped_attr_names=[ + (name, deserialize_and_fixup_type(type_, api)) + for name, type_ in data["mapped_attr_names"] + ], + mapped_mro=[ + deserialize_and_fixup_type(type_, api) + for type_ in data["mapped_mro"] + ], + ) + + def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context): msg = "[SQLAlchemy Mypy plugin] %s" % msg return api.fail(msg, ctx) @@ -94,3 +143,14 @@ def _unbound_to_instance( ) else: return typ + + +def _info_for_cls(cls, api): + if cls.info is CLASSDEF_NO_INFO: + sym = api.lookup(cls.name, cls) + if sym.node and isinstance(sym.node, TypeInfo): + info = sym.node + else: + info = cls.info + + return info |