summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/names.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/names.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/names.py54
1 files changed, 42 insertions, 12 deletions
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
index b6f911979..ad4449e5b 100644
--- a/lib/sqlalchemy/ext/mypy/names.py
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -12,11 +12,14 @@ from typing import Set
from typing import Tuple
from typing import Union
+from mypy.nodes import ARG_POS
+from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import OverloadedFuncDef
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
@@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
-NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"
_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
@@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
"sqlalchemy.sql.Column",
},
),
- "RelationshipProperty": (
+ "Relationship": (
RELATIONSHIP,
{
- "sqlalchemy.orm.relationships.RelationshipProperty",
- "sqlalchemy.orm.RelationshipProperty",
+ "sqlalchemy.orm.relationships.Relationship",
+ "sqlalchemy.orm.Relationship",
},
),
"registry": (
@@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
"sqlalchemy.orm.ColumnProperty",
},
),
- "SynonymProperty": (
+ "Synonym": (
SYNONYM_PROPERTY,
{
- "sqlalchemy.orm.descriptor_props.SynonymProperty",
- "sqlalchemy.orm.SynonymProperty",
+ "sqlalchemy.orm.descriptor_props.Synonym",
+ "sqlalchemy.orm.Synonym",
},
),
- "CompositeProperty": (
+ "Composite": (
COMPOSITE_PROPERTY,
{
- "sqlalchemy.orm.descriptor_props.CompositeProperty",
- "sqlalchemy.orm.CompositeProperty",
+ "sqlalchemy.orm.descriptor_props.Composite",
+ "sqlalchemy.orm.Composite",
},
),
"MapperProperty": (
@@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
),
"query_expression": (
QUERY_EXPRESSION,
- {"sqlalchemy.orm.query_expression"},
+ {
+ "sqlalchemy.orm.query_expression",
+ "sqlalchemy.orm._orm_constructors.query_expression",
+ },
),
}
@@ -209,7 +215,19 @@ def type_id_for_unbound_type(
def type_id_for_callee(callee: Expression) -> Optional[int]:
if isinstance(callee, (MemberExpr, NameExpr)):
- if isinstance(callee.node, FuncDef):
+ if isinstance(callee.node, OverloadedFuncDef):
+ if (
+ callee.node.impl
+ and callee.node.impl.type
+ and isinstance(callee.node.impl.type, CallableType)
+ ):
+ ret_type = get_proper_type(callee.node.impl.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
+ elif isinstance(callee.node, FuncDef):
if callee.node.type and isinstance(callee.node.type, CallableType):
ret_type = get_proper_type(callee.node.type.ret_type)
@@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]:
return type_id
else:
return None
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )