summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-02-13 20:37:12 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-02-13 20:37:12 +0000
commitd6b3c82b0c329730bcaff42b4bb39dba83acb536 (patch)
treed6b7f744a35c8d89615eeb0504ee7a4193f95642 /lib/sqlalchemy/ext
parent260ade78a70d51378de9e7b9456bfe6218859b6c (diff)
parente545298e35ea9f126054b337e4b5ba01988b29f7 (diff)
downloadsqlalchemy-d6b3c82b0c329730bcaff42b4bb39dba83acb536.tar.gz
Merge "establish mypy / typing approach for v2.0" into main
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py16
-rw-r--r--lib/sqlalchemy/ext/declarative/extensions.py2
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py33
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py2
-rw-r--r--lib/sqlalchemy/ext/mypy/infer.py40
-rw-r--r--lib/sqlalchemy/ext/mypy/names.py54
-rw-r--r--lib/sqlalchemy/ext/mypy/plugin.py13
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py36
-rw-r--r--lib/sqlalchemy/ext/orderinglist.py35
9 files changed, 177 insertions, 54 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index e6a826c64..d5119907e 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -361,7 +361,7 @@ class AssociationProxyInstance:
prop = orm.class_mapper(owning_class).get_property(target_collection)
# this was never asserted before but this should be made clear.
- if not isinstance(prop, orm.RelationshipProperty):
+ if not isinstance(prop, orm.Relationship):
raise NotImplementedError(
"association proxy to a non-relationship "
"intermediary is not supported"
@@ -717,8 +717,8 @@ class AssociationProxyInstance:
"""Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
- using the :meth:`.RelationshipProperty.Comparator.any`
- and/or :meth:`.RelationshipProperty.Comparator.has`
+ using the :meth:`.Relationship.Comparator.any`
+ and/or :meth:`.Relationship.Comparator.has`
operators of the underlying proxied attributes.
"""
@@ -737,8 +737,8 @@ class AssociationProxyInstance:
"""Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
- using the :meth:`.RelationshipProperty.Comparator.any`
- and/or :meth:`.RelationshipProperty.Comparator.has`
+ using the :meth:`.Relationship.Comparator.any`
+ and/or :meth:`.Relationship.Comparator.has`
operators of the underlying proxied attributes.
"""
@@ -859,9 +859,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
"""Produce a proxied 'contains' expression using EXISTS.
This expression will be a composed product
- using the :meth:`.RelationshipProperty.Comparator.any`,
- :meth:`.RelationshipProperty.Comparator.has`,
- and/or :meth:`.RelationshipProperty.Comparator.contains`
+ using the :meth:`.Relationship.Comparator.any`,
+ :meth:`.Relationship.Comparator.has`,
+ and/or :meth:`.Relationship.Comparator.contains`
operators of the underlying proxied attributes.
"""
diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py
index 5aff4dfe2..470ff6ad8 100644
--- a/lib/sqlalchemy/ext/declarative/extensions.py
+++ b/lib/sqlalchemy/ext/declarative/extensions.py
@@ -378,7 +378,7 @@ class DeferredReflection:
metadata = mapper.class_.metadata
for rel in mapper._props.values():
if (
- isinstance(rel, relationships.RelationshipProperty)
+ isinstance(rel, relationships.Relationship)
and rel.secondary is not None
):
if isinstance(rel.secondary, Table):
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
index 99be194cd..4e244b5b9 100644
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -36,6 +36,7 @@ from mypy.types import UnionType
from . import infer
from . import util
+from .names import expr_to_mapped_constructor
from .names import NAMED_TYPE_SQLA_MAPPED
@@ -117,6 +118,7 @@ def re_apply_declarative_assignments(
):
left_node = stmt.lvalues[0].node
+
python_type_for_type = mapped_attr_lookup[
stmt.lvalues[0].name
].type
@@ -142,7 +144,7 @@ def re_apply_declarative_assignments(
)
):
- python_type_for_type = (
+ new_python_type_for_type = (
infer.infer_type_from_right_hand_nameexpr(
api,
stmt,
@@ -152,19 +154,27 @@ def re_apply_declarative_assignments(
)
)
- if python_type_for_type is None or isinstance(
- python_type_for_type, UnboundType
+ if new_python_type_for_type is not None and not isinstance(
+ new_python_type_for_type, UnboundType
):
- continue
+ python_type_for_type = new_python_type_for_type
- # update the SQLAlchemyAttribute with the better information
- mapped_attr_lookup[
- stmt.lvalues[0].name
- ].type = python_type_for_type
+ # update the SQLAlchemyAttribute with the better
+ # information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
- update_cls_metadata = True
+ update_cls_metadata = True
- if python_type_for_type is not None:
+ # for some reason if you have a Mapped type explicitly annotated,
+ # and here you set it again, mypy forgets how to do descriptors.
+ # no idea. 100% feeling around in the dark to see what sticks
+ if (
+ not isinstance(left_node.type, Instance)
+ or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED
+ ):
+ assert python_type_for_type is not None
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
)
@@ -202,6 +212,7 @@ def apply_type_to_mapped_statement(
assert isinstance(left_node, Var)
if left_hand_explicit_type is not None:
+ lvalue.is_inferred_def = False
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
)
@@ -224,7 +235,7 @@ def apply_type_to_mapped_statement(
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
# the original right-hand side is maintained so it gets type checked
# internally
- stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
+ stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue)
def add_additional_orm_attributes(
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
index c33c30e25..bd6c6f41e 100644
--- a/lib/sqlalchemy/ext/mypy/decl_class.py
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -337,7 +337,7 @@ def _scan_declarative_decorator_stmt(
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(lambda: <function body>)
# the function body is maintained so it gets type checked internally
- rvalue = util.expr_to_mapped_constructor(
+ rvalue = names.expr_to_mapped_constructor(
LambdaExpr(stmt.func.arguments, stmt.func.body)
)
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
index 3cd946e04..6a5e99e48 100644
--- a/lib/sqlalchemy/ext/mypy/infer.py
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -42,11 +42,13 @@ def infer_type_from_right_hand_nameexpr(
left_hand_explicit_type: Optional[ProperType],
infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
-
type_id = names.type_id_for_callee(infer_from_right_side)
-
if type_id is None:
return None
+ elif type_id is names.MAPPED:
+ python_type_for_type = _infer_type_from_mapped(
+ api, stmt, node, left_hand_explicit_type, infer_from_right_side
+ )
elif type_id is names.COLUMN:
python_type_for_type = _infer_type_from_decl_column(
api, stmt, node, left_hand_explicit_type
@@ -245,7 +247,7 @@ def _infer_type_from_decl_composite_property(
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
- """Infer the type of mapping from a CompositeProperty."""
+ """Infer the type of mapping from a Composite."""
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
@@ -271,6 +273,38 @@ def _infer_type_from_decl_composite_property(
return python_type_for_type
+def _infer_type_from_mapped(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ infer_from_right_side: RefExpr,
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a right side expression
+ that returns Mapped.
+
+
+ """
+ assert isinstance(stmt.rvalue, CallExpr)
+
+ # (Pdb) print(stmt.rvalue.callee)
+ # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501
+ # (Pdb) stmt.rvalue.callee.node
+ # <mypy.nodes.FuncDef object at 0x7f8d92fb5940>
+ # (Pdb) stmt.rvalue.callee.node.type
+ # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501
+ # sqlalchemy.orm.base.Mapped[_T`-1]
+ # the_mapped_type = stmt.rvalue.callee.node.type.ret_type
+
+ # TODO: look at generic ref and either use that,
+ # or reconcile w/ what's present, etc.
+ the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa
+
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
def _infer_type_from_decl_column_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
index b6f911979..ad4449e5b 100644
--- a/lib/sqlalchemy/ext/mypy/names.py
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -12,11 +12,14 @@ from typing import Set
from typing import Tuple
from typing import Union
+from mypy.nodes import ARG_POS
+from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import OverloadedFuncDef
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
@@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
-NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"
_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
@@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
"sqlalchemy.sql.Column",
},
),
- "RelationshipProperty": (
+ "Relationship": (
RELATIONSHIP,
{
- "sqlalchemy.orm.relationships.RelationshipProperty",
- "sqlalchemy.orm.RelationshipProperty",
+ "sqlalchemy.orm.relationships.Relationship",
+ "sqlalchemy.orm.Relationship",
},
),
"registry": (
@@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
"sqlalchemy.orm.ColumnProperty",
},
),
- "SynonymProperty": (
+ "Synonym": (
SYNONYM_PROPERTY,
{
- "sqlalchemy.orm.descriptor_props.SynonymProperty",
- "sqlalchemy.orm.SynonymProperty",
+ "sqlalchemy.orm.descriptor_props.Synonym",
+ "sqlalchemy.orm.Synonym",
},
),
- "CompositeProperty": (
+ "Composite": (
COMPOSITE_PROPERTY,
{
- "sqlalchemy.orm.descriptor_props.CompositeProperty",
- "sqlalchemy.orm.CompositeProperty",
+ "sqlalchemy.orm.descriptor_props.Composite",
+ "sqlalchemy.orm.Composite",
},
),
"MapperProperty": (
@@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
),
"query_expression": (
QUERY_EXPRESSION,
- {"sqlalchemy.orm.query_expression"},
+ {
+ "sqlalchemy.orm.query_expression",
+ "sqlalchemy.orm._orm_constructors.query_expression",
+ },
),
}
@@ -209,7 +215,19 @@ def type_id_for_unbound_type(
def type_id_for_callee(callee: Expression) -> Optional[int]:
if isinstance(callee, (MemberExpr, NameExpr)):
- if isinstance(callee.node, FuncDef):
+ if isinstance(callee.node, OverloadedFuncDef):
+ if (
+ callee.node.impl
+ and callee.node.impl.type
+ and isinstance(callee.node.impl.type, CallableType)
+ ):
+ ret_type = get_proper_type(callee.node.impl.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
+ elif isinstance(callee.node, FuncDef):
if callee.node.type and isinstance(callee.node.type, CallableType):
ret_type = get_proper_type(callee.node.type.ret_type)
@@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]:
return type_id
else:
return None
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
index 0a21feb51..c9520fef3 100644
--- a/lib/sqlalchemy/ext/mypy/plugin.py
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -40,6 +40,19 @@ from . import decl_class
from . import names
from . import util
+try:
+ import sqlalchemy_stubs # noqa
+except ImportError:
+ pass
+else:
+ import sqlalchemy
+
+ raise ImportError(
+ f"The SQLAlchemy mypy plugin in SQLAlchemy "
+ f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or "
+ "sqlalchemy2-stubs installed"
+ )
+
class SQLAlchemyPlugin(Plugin):
def get_dynamic_class_hook(
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index fa42074c3..741772eac 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -10,24 +10,27 @@ from typing import Type as TypingType
from typing import TypeVar
from typing import Union
-from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import Expression
+from mypy.nodes import FuncDef
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.typeops import map_type_from_supertype
+from mypy.types import CallableType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import Type
@@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
yield stmt
+def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return ret_type
+
+ return None
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ return target_type
+ elif isinstance(callee.node, TypeInfo):
+ return callee.node
+ return None
+
+
def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
@@ -290,15 +312,3 @@ def info_for_cls(
return sym.node
return cls.info
-
-
-def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
- column_descriptor = NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- member_expr = MemberExpr(column_descriptor, "_empty_constructor")
- return CallExpr(
- member_expr,
- [expr],
- [ARG_POS],
- ["arg1"],
- )
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index 5a327d1a5..5384851b1 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -119,14 +119,28 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
"""
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import TypeVar
+
from ..orm.collections import collection
from ..orm.collections import collection_adapter
+_T = TypeVar("_T")
+OrderingFunc = Callable[[int, Sequence[_T]], int]
+
__all__ = ["ordering_list"]
-def ordering_list(attr, count_from=None, **kw):
+def ordering_list(
+ attr: str,
+ count_from: Optional[int] = None,
+ ordering_func: Optional[OrderingFunc] = None,
+ reorder_on_append: bool = False,
+) -> Callable[[], "OrderingList"]:
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper
@@ -157,7 +171,11 @@ def ordering_list(attr, count_from=None, **kw):
"""
- kw = _unsugar_count_from(count_from=count_from, **kw)
+ kw = _unsugar_count_from(
+ count_from=count_from,
+ ordering_func=ordering_func,
+ reorder_on_append=reorder_on_append,
+ )
return lambda: OrderingList(attr, **kw)
@@ -207,7 +225,7 @@ def _unsugar_count_from(**kw):
return kw
-class OrderingList(list):
+class OrderingList(List[_T]):
"""A custom list that manages position information for its children.
The :class:`.OrderingList` object is normally set up using the
@@ -216,8 +234,15 @@ class OrderingList(list):
"""
+ ordering_attr: str
+ ordering_func: OrderingFunc
+ reorder_on_append: bool
+
def __init__(
- self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ self,
+ ordering_attr: Optional[str] = None,
+ ordering_func: Optional[OrderingFunc] = None,
+ reorder_on_append: bool = False,
):
"""A custom list that manages position information for its children.
@@ -282,7 +307,7 @@ class OrderingList(list):
def _set_order_value(self, entity, value):
setattr(entity, self.ordering_attr, value)
- def reorder(self):
+ def reorder(self) -> None:
"""Synchronize ordering for the entire collection.
Sweeps through the list and ensures that each object has accurate