summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index fa42074c3..741772eac 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -10,24 +10,27 @@ from typing import Type as TypingType
from typing import TypeVar
from typing import Union
-from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import Expression
+from mypy.nodes import FuncDef
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.typeops import map_type_from_supertype
+from mypy.types import CallableType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import Type
@@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
yield stmt
+def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return ret_type
+
+ return None
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ return target_type
+ elif isinstance(callee.node, TypeInfo):
+ return callee.node
+ return None
+
+
def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
@@ -290,15 +312,3 @@ def info_for_cls(
return sym.node
return cls.info
-
-
-def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
- column_descriptor = NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- member_expr = MemberExpr(column_descriptor, "_empty_constructor")
- return CallExpr(
- member_expr,
- [expr],
- [ARG_POS],
- ["arg1"],
- )