diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/names.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/names.py | 54 |
1 files changed, 42 insertions, 12 deletions
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index b6f911979..ad4449e5b 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -12,11 +12,14 @@ from typing import Set from typing import Tuple from typing import Union +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Expression from mypy.nodes import FuncDef from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef from mypy.nodes import SymbolNode from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo @@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" NAMED_TYPE_BUILTINS_STR = "builtins.str" NAMED_TYPE_BUILTINS_LIST = "builtins.list" -NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( @@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.sql.Column", }, ), - "RelationshipProperty": ( + "Relationship": ( RELATIONSHIP, { - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.RelationshipProperty", + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.Relationship", }, ), "registry": ( @@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.orm.ColumnProperty", }, ), - "SynonymProperty": ( + "Synonym": ( SYNONYM_PROPERTY, { - "sqlalchemy.orm.descriptor_props.SynonymProperty", - "sqlalchemy.orm.SynonymProperty", + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", }, ), - "CompositeProperty": ( + "Composite": ( COMPOSITE_PROPERTY, { - "sqlalchemy.orm.descriptor_props.CompositeProperty", - "sqlalchemy.orm.CompositeProperty", + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", }, ), "MapperProperty": ( @@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { ), "query_expression": ( QUERY_EXPRESSION, - {"sqlalchemy.orm.query_expression"}, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, ), } @@ -209,7 +215,19 @@ def type_id_for_unbound_type( def type_id_for_callee(callee: Expression) -> Optional[int]: if isinstance(callee, (MemberExpr, NameExpr)): - if isinstance(callee.node, FuncDef): + if isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): if callee.node.type and isinstance(callee.node.type, CallableType): ret_type = get_proper_type(callee.node.type.ret_type) @@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]: return type_id else: return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) |