diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index fa42074c3..741772eac 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -10,24 +10,27 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union -from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import Expression +from mypy.nodes import FuncDef from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType from mypy.types import Type @@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: @@ -290,15 +312,3 @@ def info_for_cls( return sym.node return cls.info - - -def expr_to_mapped_constructor(expr: Expression) -> CallExpr: - column_descriptor = NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - member_expr = MemberExpr(column_descriptor, "_empty_constructor") - return CallExpr( - member_expr, - [expr], - [ARG_POS], - ["arg1"], - ) |