summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py62
1 files changed, 61 insertions, 1 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index 7079f3cd7..becce3ebe 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -1,18 +1,67 @@
from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from mypy.nodes import CallExpr
+from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import IfStmt
+from mypy.nodes import JsonDict
from mypy.nodes import NameExpr
from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import deserialize_and_fixup_type
from mypy.types import Instance
from mypy.types import NoneType
-from mypy.types import Type
from mypy.types import UnboundType
from mypy.types import UnionType
+class DeclClassApplied:
+ def __init__(
+ self,
+ is_mapped: bool,
+ has_table: bool,
+ mapped_attr_names: Sequence[Tuple[str, Type]],
+ mapped_mro: Sequence[Type],
+ ):
+ self.is_mapped = is_mapped
+ self.has_table = has_table
+ self.mapped_attr_names = mapped_attr_names
+ self.mapped_mro = mapped_mro
+
+ def serialize(self) -> JsonDict:
+ return {
+ "is_mapped": self.is_mapped,
+ "has_table": self.has_table,
+ "mapped_attr_names": [
+ (name, type_.serialize())
+ for name, type_ in self.mapped_attr_names
+ ],
+ "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+ }
+
+ @classmethod
+ def deserialize(
+ cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
+ ) -> "DeclClassApplied":
+
+ return DeclClassApplied(
+ is_mapped=data["is_mapped"],
+ has_table=data["has_table"],
+ mapped_attr_names=[
+ (name, deserialize_and_fixup_type(type_, api))
+ for name, type_ in data["mapped_attr_names"]
+ ],
+ mapped_mro=[
+ deserialize_and_fixup_type(type_, api)
+ for type_ in data["mapped_mro"]
+ ],
+ )
+
+
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
@@ -94,3 +143,14 @@ def _unbound_to_instance(
)
else:
return typ
+
+
+def _info_for_cls(cls, api):
+ if cls.info is CLASSDEF_NO_INFO:
+ sym = api.lookup(cls.name, cls)
+ if sym.node and isinstance(sym.node, TypeInfo):
+ info = sym.node
+ else:
+ info = cls.info
+
+ return info