diff options
Diffstat (limited to 'lib/sqlalchemy/orm/relationships.py')
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 1004 |
1 files changed, 664 insertions, 340 deletions
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 8273775ae..1186f0f54 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -17,13 +17,23 @@ from __future__ import annotations import collections from collections import abc +import dataclasses import re import typing from typing import Any from typing import Callable +from typing import cast +from typing import Collection from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -32,14 +42,19 @@ import weakref from . import attributes from . import strategy_options +from ._typing import insp_is_aliased_class +from ._typing import is_has_collection_adapter from .base import _is_mapped_class from .base import class_mapper +from .base import LoaderCallableStatus +from .base import PassiveFlag from .base import state_str from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE from .interfaces import ONETOMANY from .interfaces import PropComparator +from .interfaces import RelationshipDirection from .interfaces import StrategizedProperty from .util import _extract_mapped_subtype from .util import _orm_annotate @@ -60,6 +75,7 @@ from ..sql import visitors from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _HasClauseElement from ..sql.elements import ColumnClause +from ..sql.elements import ColumnElement from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -71,15 +87,42 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _RegistryType + from .clsregistry import _class_resolver + from .clsregistry import _ModNS + from .dependency import DependencyProcessor from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from .strategies import LazyLoader from .util import AliasedClass from .util import AliasedInsp - from ..sql.elements import ColumnElement + from ..sql._typing import _CoreAdapterProto + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _InfoType + from ..sql.annotation import _AnnotationDict + from ..sql.elements import BinaryExpression + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) + _PT = TypeVar("_PT", bound=Any) +_PT2 = TypeVar("_PT2", bound=Any) + _RelationshipArgumentType = Union[ str, @@ -111,7 +154,10 @@ _RelationshipJoinConditionArgument = Union[ str, _ColumnExpressionArgument[bool] ] _ORMOrderByArgument = Union[ - Literal[False], str, _ColumnExpressionArgument[Any] + Literal[False], + str, + _ColumnExpressionArgument[Any], + Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] _ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionArgument = Union[ @@ -120,7 +166,19 @@ _ORMColCollectionArgument = Union[ ] -def remote(expr): +_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any]) + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + + +_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]] + + +def remote(expr: _CEA) -> _CEA: """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -134,12 +192,12 @@ def remote(expr): :func:`.foreign` """ - return _annotate_columns( + return _annotate_columns( # type: ignore coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True} ) -def foreign(expr): +def foreign(expr: _CEA) -> _CEA: """Annotate a portion of a primaryjoin expression with a 'foreign' annotation. @@ -154,11 +212,71 @@ def foreign(expr): """ - return _annotate_columns( + return _annotate_columns( # type: ignore coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True} ) +@dataclasses.dataclass +class _RelationshipArg(Generic[_T1, _T2]): + """stores a user-defined parameter value that must be resolved and + parsed later at mapper configuration time. + + """ + + __slots__ = "name", "argument", "resolved" + name: str + argument: _T1 + resolved: Optional[_T2] + + def _is_populated(self) -> bool: + return self.argument is not None + + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if isinstance(attr_value, str): + self.resolved = clsregistry_resolver( + attr_value, self.name == "secondary" + )() + elif callable(attr_value) and not _is_mapped_class(attr_value): + self.resolved = attr_value() + else: + self.resolved = attr_value + + +class _RelationshipArgs(NamedTuple): + """stores user-passed parameters that are resolved at mapper configuration + time. + + """ + + secondary: _RelationshipArg[ + Optional[Union[FromClause, str]], + Optional[FromClause], + ] + primaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + secondaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + order_by: _RelationshipArg[ + _ORMOrderByArgument, + Union[Literal[None, False], Tuple[ColumnElement[Any], ...]], + ] + foreign_keys: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + remote_side: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + + @log.class_logger class Relationship( _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified @@ -184,6 +302,10 @@ class Relationship( _links_to_entity = True _is_relationship = True + _overlaps: Sequence[str] + + _lazy_strategy: LazyLoader + _persistence_only = dict( passive_deletes=False, passive_updates=True, @@ -192,56 +314,87 @@ class Relationship( cascade_backrefs=False, ) - _dependency_processor = None + _dependency_processor: Optional[DependencyProcessor] = None + + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + _join_condition: JoinCondition + order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]] + + _user_defined_foreign_keys: Set[ColumnElement[Any]] + _calculated_foreign_keys: Set[ColumnElement[Any]] + + remote_side: Set[ColumnElement[Any]] + local_columns: Set[ColumnElement[Any]] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: Optional[_ColumnPairs] + + local_remote_pairs: Optional[_ColumnPairs] + + direction: RelationshipDirection + + _init_args: _RelationshipArgs def __init__( self, argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + secondary: Optional[Union[FromClause, str]] = None, *, - uselist=None, - collection_class=None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, - order_by=False, - backref=None, - cascade_backrefs=False, - overlaps=None, - post_update=False, - cascade="save-update, merge", - viewonly=False, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, lazy: _LazyLoadArgumentType = "select", - passive_deletes=False, - passive_updates=True, - active_history=False, - enable_typechecks=True, - foreign_keys=None, - remote_side=None, - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - load_on_pending=False, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - doc=None, - bake_queries=True, - _local_remote_pairs=None, - _legacy_inactive_history_style=False, + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[Relationship.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + doc: Optional[str] = None, + bake_queries: Literal[True] = True, + cascade_backrefs: Literal[False] = False, + _local_remote_pairs: Optional[_ColumnPairs] = None, + _legacy_inactive_history_style: bool = False, ): super(Relationship, self).__init__() self.uselist = uselist self.argument = argument - self.secondary = secondary - self.primaryjoin = primaryjoin - self.secondaryjoin = secondaryjoin + + self._init_args = _RelationshipArgs( + _RelationshipArg("secondary", secondary, None), + _RelationshipArg("primaryjoin", primaryjoin, None), + _RelationshipArg("secondaryjoin", secondaryjoin, None), + _RelationshipArg("order_by", order_by, None), + _RelationshipArg("foreign_keys", foreign_keys, None), + _RelationshipArg("remote_side", remote_side, None), + ) + self.post_update = post_update - self.direction = None self.viewonly = viewonly if viewonly: self._warn_for_persistence_only_flags( @@ -258,7 +411,6 @@ class Relationship( self.sync_backref = sync_backref self.lazy = lazy self.single_parent = single_parent - self._user_defined_foreign_keys = foreign_keys self.collection_class = collection_class self.passive_deletes = passive_deletes @@ -269,7 +421,6 @@ class Relationship( ) self.passive_updates = passive_updates - self.remote_side = remote_side self.enable_typechecks = enable_typechecks self.query_class = query_class self.innerjoin = innerjoin @@ -292,23 +443,22 @@ class Relationship( self.local_remote_pairs = _local_remote_pairs self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or Relationship.Comparator - self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) if info is not None: - self.info = info + self.info.update(info) self.strategy_key = (("lazy", self.lazy),) - self._reverse_property = set() + self._reverse_property: Set[Relationship[Any]] = set() + if overlaps: - self._overlaps = set(re.split(r"\s*,\s*", overlaps)) + self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501 else: self._overlaps = () - self.cascade = cascade - - self.order_by = order_by + # mypy ignoring the @property setter + self.cascade = cascade # type: ignore self.back_populates = back_populates @@ -322,7 +472,7 @@ class Relationship( else: self.backref = backref - def _warn_for_persistence_only_flags(self, **kw): + def _warn_for_persistence_only_flags(self, **kw: Any) -> None: for k, v in kw.items(): if v != self._persistence_only[k]: # we are warning here rather than warn deprecated as this is a @@ -340,7 +490,7 @@ class Relationship( "in a future release." % (k,) ) - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: attributes.register_descriptor( mapper.class_, self.key, @@ -378,13 +528,16 @@ class Relationship( "_extra_criteria", ) + prop: RODescriptorReference[Relationship[_PT]] + _of_type: Optional[_EntityType[_PT]] + def __init__( self, - prop, - parentmapper, - adapt_to_entity=None, - of_type=None, - extra_criteria=(), + prop: Relationship[_PT], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + of_type: Optional[_EntityType[_PT]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), ): """Construction of :class:`.Relationship.Comparator` is internal to the ORM's attribute mechanics. @@ -399,15 +552,17 @@ class Relationship( self._of_type = None self._extra_criteria = extra_criteria - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Relationship.Comparator[Any]: return self.__class__( - self.property, + self.prop, self._parententity, adapt_to_entity=adapt_to_entity, of_type=self._of_type, ) - entity: _InternalEntityType + entity: _InternalEntityType[_PT] """The target entity referred to by this :class:`.Relationship.Comparator`. @@ -419,7 +574,7 @@ class Relationship( """ - mapper: Mapper[Any] + mapper: Mapper[_PT] """The target :class:`_orm.Mapper` referred to by this :class:`.Relationship.Comparator`. @@ -428,22 +583,22 @@ class Relationship( """ - def _memoized_attr_entity(self) -> _InternalEntityType: + def _memoized_attr_entity(self) -> _InternalEntityType[_PT]: if self._of_type: - return inspect(self._of_type) + return inspect(self._of_type) # type: ignore else: return self.prop.entity - def _memoized_attr_mapper(self) -> Mapper[Any]: + def _memoized_attr_mapper(self) -> Mapper[_PT]: return self.entity.mapper - def _source_selectable(self): + def _source_selectable(self) -> FromClause: if self._adapt_to_entity: return self._adapt_to_entity.selectable else: return self.property.parent._with_polymorphic_selectable - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[bool]: adapt_from = self._source_selectable() if self._of_type: of_type_entity = inspect(self._of_type) @@ -457,7 +612,7 @@ class Relationship( dest, secondary, target_adapter, - ) = self.property._create_joins( + ) = self.prop._create_joins( source_selectable=adapt_from, source_polymorphic=True, of_type_entity=of_type_entity, @@ -469,7 +624,7 @@ class Relationship( else: return pj - def of_type(self, cls): + def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]: r"""Redefine this object in terms of a polymorphic subclass. See :meth:`.PropComparator.of_type` for an example. @@ -477,16 +632,16 @@ class Relationship( """ return Relationship.Comparator( - self.property, + self.prop, self._parententity, adapt_to_entity=self._adapt_to_entity, - of_type=cls, + of_type=class_, extra_criteria=self._extra_criteria, ) def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> PropComparator[bool]: + ) -> PropComparator[Any]: """Add AND criteria. See :meth:`.PropComparator.and_` for an example. @@ -500,14 +655,14 @@ class Relationship( ) return Relationship.Comparator( - self.property, + self.prop, self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=self._of_type, extra_criteria=self._extra_criteria + exprs, ) - def in_(self, other): + def in_(self, other: Any) -> NoReturn: """Produce an IN clause - this is not implemented for :func:`_orm.relationship`-based attributes at this time. @@ -522,7 +677,7 @@ class Relationship( # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore - def __eq__(self, other): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``==`` operator. In a many-to-one context, such as:: @@ -559,7 +714,7 @@ class Relationship( or many-to-many context produce a NOT EXISTS clause. """ - if isinstance(other, (util.NoneType, expression.Null)): + if other is None or isinstance(other, expression.Null): if self.property.direction in [ONETOMANY, MANYTOMANY]: return ~self._criterion_exists() else: @@ -585,8 +740,18 @@ class Relationship( criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, ) -> Exists: + + where_criteria = ( + coercions.expect(roles.WhereHavingRole, criterion) + if criterion is not None + else None + ) + if getattr(self, "_of_type", None): - info = inspect(self._of_type) + info: Optional[_InternalEntityType[Any]] = inspect( + self._of_type + ) + assert info is not None target_mapper, to_selectable, is_aliased_class = ( info.mapper, info.selectable, @@ -597,10 +762,10 @@ class Relationship( single_crit = target_mapper._single_table_criterion if single_crit is not None: - if criterion is not None: - criterion = single_crit & criterion + if where_criteria is not None: + where_criteria = single_crit & where_criteria else: - criterion = single_crit + where_criteria = single_crit else: is_aliased_class = False to_selectable = None @@ -624,10 +789,10 @@ class Relationship( for k in kwargs: crit = getattr(self.property.mapper.class_, k) == kwargs[k] - if criterion is None: - criterion = crit + if where_criteria is None: + where_criteria = crit else: - criterion = criterion & crit + where_criteria = where_criteria & crit # annotate the *local* side of the join condition, in the case # of pj + sj this is the full primaryjoin, in the case of just @@ -638,24 +803,24 @@ class Relationship( j = _orm_annotate(pj, exclude=self.property.remote_side) if ( - criterion is not None + where_criteria is not None and target_adapter and not is_aliased_class ): # limit this adapter to annotated only? - criterion = target_adapter.traverse(criterion) + where_criteria = target_adapter.traverse(where_criteria) # only have the "joined left side" of what we # return be subject to Query adaption. The right # side of it is used for an exists() subquery and # should not correlate or otherwise reach out # to anything in the enclosing query. - if criterion is not None: - criterion = criterion._annotate( + if where_criteria is not None: + where_criteria = where_criteria._annotate( {"no_replacement_traverse": True} ) - crit = j & sql.True_._ifnone(criterion) + crit = j & sql.True_._ifnone(where_criteria) if secondary is not None: ex = ( @@ -673,7 +838,11 @@ class Relationship( ) return ex - def any(self, criterion=None, **kwargs): + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce an expression that tests a collection against particular criterion, using EXISTS. @@ -722,7 +891,11 @@ class Relationship( return self._criterion_exists(criterion, **kwargs) - def has(self, criterion=None, **kwargs): + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: """Produce an expression that tests a scalar reference against particular criterion, using EXISTS. @@ -756,7 +929,9 @@ class Relationship( ) return self._criterion_exists(criterion, **kwargs) - def contains(self, other, **kwargs): + def contains( + self, other: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> ColumnElement[bool]: """Return a simple expression that tests a collection for containment of a particular item. @@ -815,38 +990,45 @@ class Relationship( kwargs may be ignored by this operator but are required for API conformance. """ - if not self.property.uselist: + if not self.prop.uselist: raise sa_exc.InvalidRequestError( "'contains' not implemented for scalar " "attributes. Use ==" ) - clause = self.property._optimized_compare( + + clause = self.prop._optimized_compare( other, adapt_source=self.adapter ) - if self.property.secondaryjoin is not None: + if self.prop.secondaryjoin is not None: clause.negation_clause = self.__negated_contains_or_equals( other ) return clause - def __negated_contains_or_equals(self, other): - if self.property.direction == MANYTOONE: + def __negated_contains_or_equals( + self, other: Any + ) -> ColumnElement[bool]: + if self.prop.direction == MANYTOONE: state = attributes.instance_state(other) - def state_bindparam(local_col, state, remote_col): + def state_bindparam( + local_col: ColumnElement[Any], + state: InstanceState[Any], + remote_col: ColumnElement[Any], + ) -> BindParameter[Any]: dict_ = state.dict return sql.bindparam( local_col.key, type_=local_col.type, unique=True, - callable_=self.property._get_attr_w_warn_on_none( - self.property.mapper, state, dict_, remote_col + callable_=self.prop._get_attr_w_warn_on_none( + self.prop.mapper, state, dict_, remote_col ), ) - def adapt(col): + def adapt(col: _CE) -> _CE: if self.adapter: return self.adapter(col) else: @@ -876,7 +1058,7 @@ class Relationship( return ~self._criterion_exists(criterion) - def __ne__(self, other): + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``!=`` operator. In a many-to-one context, such as:: @@ -915,7 +1097,7 @@ class Relationship( or many-to-many context produce an EXISTS clause. """ - if isinstance(other, (util.NoneType, expression.Null)): + if other is None or isinstance(other, expression.Null): if self.property.direction == MANYTOONE: return _orm_annotate( ~self.property._optimized_compare( @@ -934,12 +1116,10 @@ class Relationship( else: return _orm_annotate(self.__negated_contains_or_equals(other)) - def _memoized_attr_property(self): + def _memoized_attr_property(self) -> Relationship[_PT]: self.prop.parent._check_configure() return self.prop - comparator: Comparator[_T] - def _with_parent( self, instance: object, @@ -947,10 +1127,11 @@ class Relationship( from_entity: Optional[_EntityType[Any]] = None, ) -> ColumnElement[bool]: assert instance is not None - adapt_source = None + adapt_source: Optional[_CoreAdapterProto] = None if from_entity is not None: - insp = inspect(from_entity) - if insp.is_aliased_class: + insp: Optional[_InternalEntityType[Any]] = inspect(from_entity) + assert insp is not None + if insp_is_aliased_class(insp): adapt_source = insp._adapter.adapt_clause return self._optimized_compare( instance, @@ -961,11 +1142,11 @@ class Relationship( def _optimized_compare( self, - state, - value_is_parent=False, - adapt_source=None, - alias_secondary=True, - ): + state: Any, + value_is_parent: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + alias_secondary: bool = True, + ) -> ColumnElement[bool]: if state is not None: try: state = inspect(state) @@ -1005,7 +1186,7 @@ class Relationship( dict_ = attributes.instance_dict(state.obj()) - def visit_bindparam(bindparam): + def visit_bindparam(bindparam: BindParameter[Any]) -> None: if bindparam._identifying_key in bind_to_col: bindparam.callable = self._get_attr_w_warn_on_none( mapper, @@ -1027,7 +1208,13 @@ class Relationship( criterion = adapt_source(criterion) return criterion - def _get_attr_w_warn_on_none(self, mapper, state, dict_, column): + def _get_attr_w_warn_on_none( + self, + mapper: Mapper[Any], + state: InstanceState[Any], + dict_: _InstanceDict, + column: ColumnElement[Any], + ) -> Callable[[], Any]: """Create the callable that is used in a many-to-one expression. E.g.:: @@ -1077,9 +1264,14 @@ class Relationship( # this feature was added explicitly for use in this method. state._track_last_known_value(prop.key) - def _go(): - last_known = to_return = state._last_known_values[prop.key] - existing_is_available = last_known is not attributes.NO_VALUE + lkv_fixed = state._last_known_values + + def _go() -> Any: + assert lkv_fixed is not None + last_known = to_return = lkv_fixed[prop.key] + existing_is_available = ( + last_known is not LoaderCallableStatus.NO_VALUE + ) # we support that the value may have changed. so here we # try to get the most recent value including re-fetching. @@ -1089,19 +1281,19 @@ class Relationship( state, dict_, column, - passive=attributes.PASSIVE_OFF + passive=PassiveFlag.PASSIVE_OFF if state.persistent - else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK, + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, ) - if current_value is attributes.NEVER_SET: + if current_value is LoaderCallableStatus.NEVER_SET: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " "%s; no value has been set for this column" % (column, state_str(state)) ) - elif current_value is attributes.PASSIVE_NO_RESULT: + elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " @@ -1121,7 +1313,11 @@ class Relationship( return _go - def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): + def _lazy_none_clause( + self, + reverse_direction: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + ) -> ColumnElement[bool]: if not reverse_direction: criterion, bind_to_col = ( self._lazy_strategy._lazywhere, @@ -1139,20 +1335,20 @@ class Relationship( criterion = adapt_source(criterion) return criterion - def __str__(self): + def __str__(self) -> str: return str(self.parent.class_.__name__) + "." + self.key def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: if load: for r in self._reverse_property: @@ -1167,6 +1363,8 @@ class Relationship( if self.uselist: impl = source_state.get_impl(self.key) + + assert is_has_collection_adapter(impl) instances_iterable = impl.get_collection(source_state, source_dict) # if this is a CollectionAttributeImpl, then empty should @@ -1204,9 +1402,9 @@ class Relationship( for c in dest_list: coll.append_without_event(c) else: - dest_state.get_impl(self.key).set( - dest_state, dest_dict, dest_list, _adapt=False - ) + dest_impl = dest_state.get_impl(self.key) + assert is_has_collection_adapter(dest_impl) + dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False) else: current = source_dict[self.key] if current is not None: @@ -1231,8 +1429,12 @@ class Relationship( ) def _value_as_iterable( - self, state, dict_, key, passive=attributes.PASSIVE_OFF - ): + self, + state: InstanceState[_O], + dict_: _InstanceDict, + key: str, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Sequence[Tuple[InstanceState[_O], _O]]: """Return a list of tuples (state, obj) for the given key. @@ -1241,9 +1443,9 @@ class Relationship( impl = state.manager[key].impl x = impl.get(state, dict_, passive=passive) - if x is attributes.PASSIVE_NO_RESULT or x is None: + if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None: return [] - elif hasattr(impl, "get_collection"): + elif is_has_collection_adapter(impl): return [ (attributes.instance_state(o), o) for o in impl.get_collection(state, dict_, x, passive=passive) @@ -1252,19 +1454,23 @@ class Relationship( return [(attributes.instance_state(x), x)] def cascade_iterator( - self, type_, state, dict_, visited_states, halt_on=None - ): + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]: # assert type_ in self._cascade # only actively lazy load on the 'delete' cascade if type_ != "delete" or self.passive_deletes: - passive = attributes.PASSIVE_NO_INITIALIZE + passive = PassiveFlag.PASSIVE_NO_INITIALIZE else: - passive = attributes.PASSIVE_OFF + passive = PassiveFlag.PASSIVE_OFF if type_ == "save-update": tuples = state.manager[self.key].impl.get_all_pending(state, dict_) - else: tuples = self._value_as_iterable( state, dict_, self.key, passive=passive @@ -1285,6 +1491,7 @@ class Relationship( # see [ticket:2229] continue + assert instance_state is not None instance_dict = attributes.instance_dict(c) if halt_on and halt_on(instance_state): @@ -1308,14 +1515,16 @@ class Relationship( yield c, instance_mapper, instance_state, instance_dict @property - def _effective_sync_backref(self): + def _effective_sync_backref(self) -> bool: if self.viewonly: return False else: return self.sync_backref is not False @staticmethod - def _check_sync_backref(rel_a, rel_b): + def _check_sync_backref( + rel_a: Relationship[Any], rel_b: Relationship[Any] + ) -> None: if rel_a.viewonly and rel_b.sync_backref: raise sa_exc.InvalidRequestError( "Relationship %s cannot specify sync_backref=True since %s " @@ -1328,7 +1537,7 @@ class Relationship( ): rel_b.sync_backref = False - def _add_reverse_property(self, key): + def _add_reverse_property(self, key: str) -> None: other = self.mapper.get_property(key, _configure_mappers=False) if not isinstance(other, Relationship): raise sa_exc.InvalidRequestError( @@ -1361,7 +1570,8 @@ class Relationship( ) if ( - self.direction in (ONETOMANY, MANYTOONE) + other._configure_started + and self.direction in (ONETOMANY, MANYTOONE) and self.direction == other.direction ): raise sa_exc.ArgumentError( @@ -1372,7 +1582,7 @@ class Relationship( ) @util.memoized_property - def entity(self) -> Union["Mapper", "AliasedInsp"]: + def entity(self) -> _InternalEntityType[_T]: """Return the target mapped entity, which is an inspect() of the class or aliased class that is referred towards. @@ -1388,7 +1598,7 @@ class Relationship( """ return self.entity.mapper - def do_init(self): + def do_init(self) -> None: self._check_conflicts() self._process_dependent_arguments() self._setup_entity() @@ -1399,14 +1609,16 @@ class Relationship( self._generate_backref() self._join_condition._warn_for_conflicting_sync_targets() super(Relationship, self).do_init() - self._lazy_strategy = self._get_strategy((("lazy", "select"),)) + self._lazy_strategy = cast( + "LazyLoader", self._get_strategy((("lazy", "select"),)) + ) - def _setup_registry_dependencies(self): + def _setup_registry_dependencies(self) -> None: self.parent.mapper.registry._set_depends_on( self.entity.mapper.registry ) - def _process_dependent_arguments(self): + def _process_dependent_arguments(self) -> None: """Convert incoming configuration arguments to their proper form. @@ -1417,78 +1629,80 @@ class Relationship( # accept callables for other attributes which may require # deferred initialization. This technique is used # by declarative "string configs" and some recipes. + init_args = self._init_args + for attr in ( "order_by", "primaryjoin", "secondaryjoin", "secondary", - "_user_defined_foreign_keys", + "foreign_keys", "remote_side", ): - attr_value = getattr(self, attr) - - if isinstance(attr_value, str): - setattr( - self, - attr, - self._clsregistry_resolve_arg( - attr_value, favor_tables=attr == "secondary" - )(), - ) - elif callable(attr_value) and not _is_mapped_class(attr_value): - setattr(self, attr, attr_value()) + + rel_arg = getattr(init_args, attr) + + rel_arg._resolve_against_registry(self._clsregistry_resolvers[1]) # remove "annotations" which are present if mapped class # descriptors are used to create the join expression. for attr in "primaryjoin", "secondaryjoin": - val = getattr(self, attr) + rel_arg = getattr(init_args, attr) + val = rel_arg.resolved if val is not None: - setattr( - self, - attr, - _orm_deannotate( - coercions.expect( - roles.ColumnArgumentRole, val, argname=attr - ) - ), + rel_arg.resolved = _orm_deannotate( + coercions.expect( + roles.ColumnArgumentRole, val, argname=attr + ) ) - if self.secondary is not None and _is_mapped_class(self.secondary): + secondary = init_args.secondary.resolved + if secondary is not None and _is_mapped_class(secondary): raise sa_exc.ArgumentError( "secondary argument %s passed to to relationship() %s must " "be a Table object or other FROM clause; can't send a mapped " "class directly as rows in 'secondary' are persisted " "independently of a class that is mapped " - "to that same table." % (self.secondary, self) + "to that same table." % (secondary, self) ) # ensure expressions in self.order_by, foreign_keys, # remote_side are all columns, not strings. - if self.order_by is not False and self.order_by is not None: + if ( + init_args.order_by.resolved is not False + and init_args.order_by.resolved is not None + ): self.order_by = tuple( coercions.expect( roles.ColumnArgumentRole, x, argname="order_by" ) - for x in util.to_list(self.order_by) + for x in util.to_list(init_args.order_by.resolved) ) + else: + self.order_by = False self._user_defined_foreign_keys = util.column_set( coercions.expect( roles.ColumnArgumentRole, x, argname="foreign_keys" ) - for x in util.to_column_set(self._user_defined_foreign_keys) + for x in util.to_column_set(init_args.foreign_keys.resolved) ) self.remote_side = util.column_set( coercions.expect( roles.ColumnArgumentRole, x, argname="remote_side" ) - for x in util.to_column_set(self.remote_side) + for x in util.to_column_set(init_args.remote_side.resolved) ) def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: argument = _extract_mapped_subtype( annotation, cls, @@ -1502,17 +1716,19 @@ class Relationship( if hasattr(argument, "__origin__"): - collection_class = argument.__origin__ + collection_class = argument.__origin__ # type: ignore if issubclass(collection_class, abc.Collection): if self.collection_class is None: self.collection_class = collection_class else: self.uselist = False - if argument.__args__: - if issubclass(argument.__origin__, typing.Mapping): - type_arg = argument.__args__[1] + if argument.__args__: # type: ignore + if issubclass( + argument.__origin__, typing.Mapping # type: ignore + ): + type_arg = argument.__args__[1] # type: ignore else: - type_arg = argument.__args__[0] + type_arg = argument.__args__[0] # type: ignore if hasattr(type_arg, "__forward_arg__"): str_argument = type_arg.__forward_arg__ argument = str_argument @@ -1523,12 +1739,12 @@ class Relationship( f"Generic alias {argument} requires an argument" ) elif hasattr(argument, "__forward_arg__"): - argument = argument.__forward_arg__ + argument = argument.__forward_arg__ # type: ignore self.argument = argument @util.preload_module("sqlalchemy.orm.mapper") - def _setup_entity(self, __argument=None): + def _setup_entity(self, __argument: Any = None) -> None: if "entity" in self.__dict__: return @@ -1539,42 +1755,51 @@ class Relationship( else: argument = self.argument + resolved_argument: _ExternalEntityType[Any] + if isinstance(argument, str): - argument = self._clsregistry_resolve_name(argument)() + # we might want to cleanup clsregistry API to make this + # more straightforward + resolved_argument = cast( + "_ExternalEntityType[Any]", + self._clsregistry_resolve_name(argument)(), + ) elif callable(argument) and not isinstance( argument, (type, mapperlib.Mapper) ): - argument = argument() + resolved_argument = argument() else: - argument = argument + resolved_argument = argument - if isinstance(argument, type): - entity = class_mapper(argument, configure=False) + entity: _InternalEntityType[Any] + + if isinstance(resolved_argument, type): + entity = class_mapper(resolved_argument, configure=False) else: try: - entity = inspect(argument) + entity = inspect(resolved_argument) except sa_exc.NoInspectionAvailable: - entity = None + entity = None # type: ignore if not hasattr(entity, "mapper"): raise sa_exc.ArgumentError( "relationship '%s' expects " "a class or a mapper argument (received: %s)" - % (self.key, type(argument)) + % (self.key, type(resolved_argument)) ) self.entity = entity # type: ignore self.target = self.entity.persist_selectable - def _setup_join_conditions(self): + def _setup_join_conditions(self) -> None: self._join_condition = jc = JoinCondition( parent_persist_selectable=self.parent.persist_selectable, child_persist_selectable=self.entity.persist_selectable, parent_local_selectable=self.parent.local_table, child_local_selectable=self.entity.local_table, - primaryjoin=self.primaryjoin, - secondary=self.secondary, - secondaryjoin=self.secondaryjoin, + primaryjoin=self._init_args.primaryjoin.resolved, + secondary=self._init_args.secondary.resolved, + secondaryjoin=self._init_args.secondaryjoin.resolved, parent_equivalents=self.parent._equivalent_columns, child_equivalents=self.mapper._equivalent_columns, consider_as_foreign_keys=self._user_defined_foreign_keys, @@ -1587,6 +1812,7 @@ class Relationship( ) self.primaryjoin = jc.primaryjoin self.secondaryjoin = jc.secondaryjoin + self.secondary = jc.secondary self.direction = jc.direction self.local_remote_pairs = jc.local_remote_pairs self.remote_side = jc.remote_columns @@ -1596,21 +1822,30 @@ class Relationship( self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs @property - def _clsregistry_resolve_arg(self): + def _clsregistry_resolve_arg( + self, + ) -> Callable[[str, bool], _class_resolver]: return self._clsregistry_resolvers[1] @property - def _clsregistry_resolve_name(self): + def _clsregistry_resolve_name( + self, + ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]: return self._clsregistry_resolvers[0] @util.memoized_property @util.preload_module("sqlalchemy.orm.clsregistry") - def _clsregistry_resolvers(self): + def _clsregistry_resolvers( + self, + ) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], + ]: _resolver = util.preloaded.orm_clsregistry._resolver return _resolver(self.parent.class_, self) - def _check_conflicts(self): + def _check_conflicts(self) -> None: """Test that this relationship is legal, warn about inheritance conflicts.""" if self.parent.non_primary and not class_mapper( @@ -1637,10 +1872,10 @@ class Relationship( return self._cascade @cascade.setter - def cascade(self, cascade: Union[str, CascadeOptions]): + def cascade(self, cascade: Union[str, CascadeOptions]) -> None: self._set_cascade(cascade) - def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]): + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None: cascade = CascadeOptions(cascade_arg) if self.viewonly: @@ -1655,7 +1890,7 @@ class Relationship( if self._dependency_processor: self._dependency_processor.cascade = cascade - def _check_cascade_settings(self, cascade): + def _check_cascade_settings(self, cascade: CascadeOptions) -> None: if ( cascade.delete_orphan and not self.single_parent @@ -1699,7 +1934,7 @@ class Relationship( (self.key, self.parent.class_) ) - def _persists_for(self, mapper): + def _persists_for(self, mapper: Mapper[Any]) -> bool: """Return True if this property will persist values on behalf of the given mapper. @@ -1710,16 +1945,15 @@ class Relationship( and mapper.relationships[self.key] is self ) - def _columns_are_mapped(self, *cols): + def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool: """Return True if all columns in the given collection are mapped by the tables referenced by this :class:`.Relationship`. """ + + secondary = self._init_args.secondary.resolved for c in cols: - if ( - self.secondary is not None - and self.secondary.c.contains_column(c) - ): + if secondary is not None and secondary.c.contains_column(c): continue if not self.parent.persist_selectable.c.contains_column( c @@ -1727,13 +1961,14 @@ class Relationship( return False return True - def _generate_backref(self): + def _generate_backref(self) -> None: """Interpret the 'backref' instruction to create a :func:`_orm.relationship` complementary to this one.""" if self.parent.non_primary: return if self.backref is not None and not self.back_populates: + kwargs: Dict[str, Any] if isinstance(self.backref, str): backref_key, kwargs = self.backref, {} else: @@ -1805,7 +2040,7 @@ class Relationship( self._add_reverse_property(self.back_populates) @util.preload_module("sqlalchemy.orm.dependency") - def _post_init(self): + def _post_init(self) -> None: dependency = util.preloaded.orm_dependency if self.uselist is None: @@ -1816,7 +2051,7 @@ class Relationship( )(self) @util.memoized_property - def _use_get(self): + def _use_get(self) -> bool: """memoize the 'use_get' attribute of this RelationshipLoader's lazyloader.""" @@ -1824,18 +2059,25 @@ class Relationship( return strategy.use_get @util.memoized_property - def _is_self_referential(self): + def _is_self_referential(self) -> bool: return self.mapper.common_parent(self.parent) def _create_joins( self, - source_polymorphic=False, - source_selectable=None, - dest_selectable=None, - of_type_entity=None, - alias_secondary=False, - extra_criteria=(), - ): + source_polymorphic: bool = False, + source_selectable: Optional[FromClause] = None, + dest_selectable: Optional[FromClause] = None, + of_type_entity: Optional[_InternalEntityType[Any]] = None, + alias_secondary: bool = False, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + FromClause, + FromClause, + Optional[FromClause], + Optional[ClauseAdapter], + ]: aliased = False @@ -1905,38 +2147,56 @@ class Relationship( ) -def _annotate_columns(element, annotations): - def clone(elem): +def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE: + def clone(elem: _CE) -> _CE: if isinstance(elem, expression.ColumnClause): - elem = elem._annotate(annotations.copy()) + elem = elem._annotate(annotations.copy()) # type: ignore elem._copy_internals(clone=clone) return elem if element is not None: element = clone(element) - clone = None # remove gc cycles + clone = None # type: ignore # remove gc cycles return element class JoinCondition: + + primaryjoin_initial: Optional[ColumnElement[bool]] + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + prop: Relationship[Any] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: _ColumnPairs + direction: RelationshipDirection + + parent_persist_selectable: FromClause + child_persist_selectable: FromClause + parent_local_selectable: FromClause + child_local_selectable: FromClause + + _local_remote_pairs: Optional[_ColumnPairs] + def __init__( self, - parent_persist_selectable, - child_persist_selectable, - parent_local_selectable, - child_local_selectable, - primaryjoin=None, - secondary=None, - secondaryjoin=None, - parent_equivalents=None, - child_equivalents=None, - consider_as_foreign_keys=None, - local_remote_pairs=None, - remote_side=None, - self_referential=False, - prop=None, - support_sync=True, - can_be_synced_fn=lambda *c: True, + parent_persist_selectable: FromClause, + child_persist_selectable: FromClause, + parent_local_selectable: FromClause, + child_local_selectable: FromClause, + primaryjoin: Optional[ColumnElement[bool]] = None, + secondary: Optional[FromClause] = None, + secondaryjoin: Optional[ColumnElement[bool]] = None, + parent_equivalents: Optional[_EquivalentColumnMap] = None, + child_equivalents: Optional[_EquivalentColumnMap] = None, + consider_as_foreign_keys: Any = None, + local_remote_pairs: Optional[_ColumnPairs] = None, + remote_side: Any = None, + self_referential: Any = False, + prop: Optional[Relationship[Any]] = None, + support_sync: bool = True, + can_be_synced_fn: Callable[..., bool] = lambda *c: True, ): self.parent_persist_selectable = parent_persist_selectable self.parent_local_selectable = parent_local_selectable @@ -1944,7 +2204,7 @@ class JoinCondition: self.child_local_selectable = child_local_selectable self.parent_equivalents = parent_equivalents self.child_equivalents = child_equivalents - self.primaryjoin = primaryjoin + self.primaryjoin_initial = primaryjoin self.secondaryjoin = secondaryjoin self.secondary = secondary self.consider_as_foreign_keys = consider_as_foreign_keys @@ -1954,7 +2214,10 @@ class JoinCondition: self.self_referential = self_referential self.support_sync = support_sync self.can_be_synced_fn = can_be_synced_fn + self._determine_joins() + assert self.primaryjoin is not None + self._sanitize_joins() self._annotate_fks() self._annotate_remote() @@ -1968,7 +2231,7 @@ class JoinCondition: self._check_remote_side() self._log_joins() - def _log_joins(self): + def _log_joins(self) -> None: if self.prop is None: return log = self.prop.logger @@ -2008,7 +2271,7 @@ class JoinCondition: ) log.info("%s relationship direction %s", self.prop, self.direction) - def _sanitize_joins(self): + def _sanitize_joins(self) -> None: """remove the parententity annotation from our join conditions which can leak in here based on some declarative patterns and maybe others. @@ -2026,7 +2289,7 @@ class JoinCondition: self.secondaryjoin, values=("parententity", "proxy_key") ) - def _determine_joins(self): + def _determine_joins(self) -> None: """Determine the 'primaryjoin' and 'secondaryjoin' attributes, if not passed to the constructor already. @@ -2056,21 +2319,25 @@ class JoinCondition: a_subset=self.child_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) - if self.primaryjoin is None: + if self.primaryjoin_initial is None: self.primaryjoin = join_condition( self.parent_persist_selectable, self.secondary, a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) + else: + self.primaryjoin = self.primaryjoin_initial else: - if self.primaryjoin is None: + if self.primaryjoin_initial is None: self.primaryjoin = join_condition( self.parent_persist_selectable, self.child_persist_selectable, a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) + else: + self.primaryjoin = self.primaryjoin_initial except sa_exc.NoForeignKeysError as nfe: if self.secondary is not None: raise sa_exc.NoForeignKeysError( @@ -2118,15 +2385,16 @@ class JoinCondition: ) from afe @property - def primaryjoin_minus_local(self): + def primaryjoin_minus_local(self) -> ColumnElement[bool]: return _deep_deannotate(self.primaryjoin, values=("local", "remote")) @property - def secondaryjoin_minus_local(self): + def secondaryjoin_minus_local(self) -> ColumnElement[bool]: + assert self.secondaryjoin is not None return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) @util.memoized_property - def primaryjoin_reverse_remote(self): + def primaryjoin_reverse_remote(self) -> ColumnElement[bool]: """Return the primaryjoin condition suitable for the "reverse" direction. @@ -2138,7 +2406,7 @@ class JoinCondition: """ if self._has_remote_annotations: - def replace(element): + def replace(element: _CE, **kw: Any) -> Optional[_CE]: if "remote" in element._annotations: v = dict(element._annotations) del v["remote"] @@ -2150,6 +2418,8 @@ class JoinCondition: v["remote"] = True return element._with_annotations(v) + return None + return visitors.replacement_traverse(self.primaryjoin, {}, replace) else: if self._has_foreign_annotations: @@ -2160,7 +2430,7 @@ class JoinCondition: else: return _deep_deannotate(self.primaryjoin) - def _has_annotation(self, clause, annotation): + def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool: for col in visitors.iterate(clause, {}): if annotation in col._annotations: return True @@ -2168,14 +2438,14 @@ class JoinCondition: return False @util.memoized_property - def _has_foreign_annotations(self): + def _has_foreign_annotations(self) -> bool: return self._has_annotation(self.primaryjoin, "foreign") @util.memoized_property - def _has_remote_annotations(self): + def _has_remote_annotations(self) -> bool: return self._has_annotation(self.primaryjoin, "remote") - def _annotate_fks(self): + def _annotate_fks(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'foreign' annotations marking columns considered as foreign. @@ -2189,10 +2459,11 @@ class JoinCondition: else: self._annotate_present_fks() - def _annotate_from_fk_list(self): - def check_fk(col): - if col in self.consider_as_foreign_keys: - return col._annotate({"foreign": True}) + def _annotate_from_fk_list(self) -> None: + def check_fk(element: _CE, **kw: Any) -> Optional[_CE]: + if element in self.consider_as_foreign_keys: + return element._annotate({"foreign": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, check_fk @@ -2202,13 +2473,15 @@ class JoinCondition: self.secondaryjoin, {}, check_fk ) - def _annotate_present_fks(self): + def _annotate_present_fks(self) -> None: if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: secondarycols = set() - def is_foreign(a, b): + def is_foreign( + a: ColumnElement[Any], b: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: if isinstance(a, schema.Column) and isinstance(b, schema.Column): if a.references(b): return a @@ -2221,7 +2494,9 @@ class JoinCondition: elif b in secondarycols and a not in secondarycols: return b - def visit_binary(binary): + return None + + def visit_binary(binary: BinaryExpression[Any]) -> None: if not isinstance( binary.left, sql.ColumnElement ) or not isinstance(binary.right, sql.ColumnElement): @@ -2248,16 +2523,17 @@ class JoinCondition: self.secondaryjoin, {}, {"binary": visit_binary} ) - def _refers_to_parent_table(self): + def _refers_to_parent_table(self) -> bool: """Return True if the join condition contains column comparisons where both columns are in both tables. """ pt = self.parent_persist_selectable mt = self.child_persist_selectable - result = [False] + result = False - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: + nonlocal result c, f = binary.left, binary.right if ( isinstance(c, expression.ColumnClause) @@ -2267,19 +2543,19 @@ class JoinCondition: and mt.is_derived_from(c.table) and mt.is_derived_from(f.table) ): - result[0] = True + result = True visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary}) - return result[0] + return result - def _tables_overlap(self): + def _tables_overlap(self) -> bool: """Return True if parent/child tables have some overlap.""" return selectables_overlap( self.parent_persist_selectable, self.child_persist_selectable ) - def _annotate_remote(self): + def _annotate_remote(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'remote' annotations marking columns considered as part of the 'remote' side. @@ -2301,30 +2577,38 @@ class JoinCondition: else: self._annotate_remote_distinct_selectables() - def _annotate_remote_secondary(self): + def _annotate_remote_secondary(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when 'secondary' is present. """ - def repl(element): - if self.secondary.c.contains_column(element): + assert self.secondary is not None + fixed_secondary = self.secondary + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if fixed_secondary.c.contains_column(element): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) + + assert self.secondaryjoin is not None self.secondaryjoin = visitors.replacement_traverse( self.secondaryjoin, {}, repl ) - def _annotate_selfref(self, fn, remote_side_given): + def _annotate_selfref( + self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool + ) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the relationship is detected as self-referential. """ - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: equated = binary.left.compare(binary.right) if isinstance(binary.left, expression.ColumnClause) and isinstance( binary.right, expression.ColumnClause @@ -2341,7 +2625,7 @@ class JoinCondition: self.primaryjoin, {}, {"binary": visit_binary} ) - def _annotate_remote_from_args(self): + def _annotate_remote_from_args(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the 'remote_side' or '_local_remote_pairs' arguments are used. @@ -2363,17 +2647,18 @@ class JoinCondition: self._annotate_selfref(lambda col: col in remote_side, True) else: - def repl(element): + def repl(element: _CE, **kw: Any) -> Optional[_CE]: # use set() to avoid generating ``__eq__()`` expressions # against each element if element in set(remote_side): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) - def _annotate_remote_with_overlap(self): + def _annotate_remote_with_overlap(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the parent/child tables have some set of tables in common, though is not a fully self-referential @@ -2381,7 +2666,7 @@ class JoinCondition: """ - def visit_binary(binary): + def visit_binary(binary: BinaryExpression[Any]) -> None: binary.left, binary.right = proc_left_right( binary.left, binary.right ) @@ -2393,7 +2678,9 @@ class JoinCondition: self.prop is not None and self.prop.mapper is not self.prop.parent ) - def proc_left_right(left, right): + def proc_left_right( + left: ColumnElement[Any], right: ColumnElement[Any] + ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]: if isinstance(left, expression.ColumnClause) and isinstance( right, expression.ColumnClause ): @@ -2420,32 +2707,33 @@ class JoinCondition: self.primaryjoin, {}, {"binary": visit_binary} ) - def _annotate_remote_distinct_selectables(self): + def _annotate_remote_distinct_selectables(self) -> None: """annotate 'remote' in primaryjoin, secondaryjoin when the parent/child tables are entirely separate. """ - def repl(element): + def repl(element: _CE, **kw: Any) -> Optional[_CE]: if self.child_persist_selectable.c.contains_column(element) and ( not self.parent_local_selectable.c.contains_column(element) or self.child_local_selectable.c.contains_column(element) ): return element._annotate({"remote": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl ) - def _warn_non_column_elements(self): + def _warn_non_column_elements(self) -> None: util.warn( "Non-simple column elements in primary " "join condition for property %s - consider using " "remote() annotations to mark the remote side." % self.prop ) - def _annotate_local(self): + def _annotate_local(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'local' annotations. @@ -2466,29 +2754,31 @@ class JoinCondition: else: local_side = util.column_set(self.parent_persist_selectable.c) - def locals_(elem): - if "remote" not in elem._annotations and elem in local_side: - return elem._annotate({"local": True}) + def locals_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" not in element._annotations and element in local_side: + return element._annotate({"local": True}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, locals_ ) - def _annotate_parentmapper(self): + def _annotate_parentmapper(self) -> None: if self.prop is None: return - def parentmappers_(elem): - if "remote" in elem._annotations: - return elem._annotate({"parentmapper": self.prop.mapper}) - elif "local" in elem._annotations: - return elem._annotate({"parentmapper": self.prop.parent}) + def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + return element._annotate({"parentmapper": self.prop.mapper}) + elif "local" in element._annotations: + return element._annotate({"parentmapper": self.prop.parent}) + return None self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, parentmappers_ ) - def _check_remote_side(self): + def _check_remote_side(self) -> None: if not self.local_remote_pairs: raise sa_exc.ArgumentError( "Relationship %s could " @@ -2501,7 +2791,9 @@ class JoinCondition: "the relationship." % (self.prop,) ) - def _check_foreign_cols(self, join_condition, primary): + def _check_foreign_cols( + self, join_condition: ColumnElement[bool], primary: bool + ) -> None: """Check the foreign key columns collected and emit error messages.""" @@ -2567,7 +2859,7 @@ class JoinCondition: ) raise sa_exc.ArgumentError(err) - def _determine_direction(self): + def _determine_direction(self) -> None: """Determine if this relationship is one to many, many to one, many to many. @@ -2651,7 +2943,9 @@ class JoinCondition: "nor the child's mapped tables" % self.prop ) - def _deannotate_pairs(self, collection): + def _deannotate_pairs( + self, collection: _ColumnPairIterable + ) -> _MutableColumnPairs: """provide deannotation for the various lists of pairs, so that using them in hashes doesn't incur high-overhead __eq__() comparisons against @@ -2660,13 +2954,22 @@ class JoinCondition: """ return [(x._deannotate(), y._deannotate()) for x, y in collection] - def _setup_pairs(self): - sync_pairs = [] - lrp = util.OrderedSet([]) - secondary_sync_pairs = [] - - def go(joincond, collection): - def visit_binary(binary, left, right): + def _setup_pairs(self) -> None: + sync_pairs: _MutableColumnPairs = [] + lrp: util.OrderedSet[ + Tuple[ColumnElement[Any], ColumnElement[Any]] + ] = util.OrderedSet([]) + secondary_sync_pairs: _MutableColumnPairs = [] + + def go( + joincond: ColumnElement[bool], + collection: _MutableColumnPairs, + ) -> None: + def visit_binary( + binary: BinaryExpression[Any], + left: ColumnElement[Any], + right: ColumnElement[Any], + ) -> None: if ( "remote" in right._annotations and "remote" not in left._annotations @@ -2703,9 +3006,12 @@ class JoinCondition: secondary_sync_pairs ) - _track_overlapping_sync_targets = weakref.WeakKeyDictionary() + _track_overlapping_sync_targets: weakref.WeakKeyDictionary[ + ColumnElement[Any], + weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]], + ] = weakref.WeakKeyDictionary() - def _warn_for_conflicting_sync_targets(self): + def _warn_for_conflicting_sync_targets(self) -> None: if not self.support_sync: return @@ -2793,18 +3099,20 @@ class JoinCondition: self._track_overlapping_sync_targets[to_][self.prop] = from_ @util.memoized_property - def remote_columns(self): + def remote_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("remote") @util.memoized_property - def local_columns(self): + def local_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("local") @util.memoized_property - def foreign_key_columns(self): + def foreign_key_columns(self) -> Set[ColumnElement[Any]]: return self._gather_join_annotations("foreign") - def _gather_join_annotations(self, annotation): + def _gather_join_annotations( + self, annotation: str + ) -> Set[ColumnElement[Any]]: s = set( self._gather_columns_with_annotation(self.primaryjoin, annotation) ) @@ -2816,24 +3124,32 @@ class JoinCondition: ) return {x._deannotate() for x in s} - def _gather_columns_with_annotation(self, clause, *annotation): - annotation = set(annotation) + def _gather_columns_with_annotation( + self, clause: ColumnElement[Any], *annotation: Iterable[str] + ) -> Set[ColumnElement[Any]]: + annotation_set = set(annotation) return set( [ - col + cast(ColumnElement[Any], col) for col in visitors.iterate(clause, {}) - if annotation.issubset(col._annotations) + if annotation_set.issubset(col._annotations) ] ) def join_targets( self, - source_selectable, - dest_selectable, - aliased, - single_crit=None, - extra_criteria=(), - ): + source_selectable: Optional[FromClause], + dest_selectable: FromClause, + aliased: bool, + single_crit: Optional[ColumnElement[bool]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + Optional[FromClause], + Optional[ClauseAdapter], + FromClause, + ]: """Given a source and destination selectable, create a join between them. @@ -2923,9 +3239,15 @@ class JoinCondition: dest_selectable, ) - def create_lazy_clause(self, reverse_direction=False): - binds = util.column_dict() - equated_columns = util.column_dict() + def create_lazy_clause( + self, reverse_direction: bool = False + ) -> Tuple[ + ColumnElement[bool], + Dict[str, ColumnElement[Any]], + Dict[ColumnElement[Any], ColumnElement[Any]], + ]: + binds: Dict[ColumnElement[Any], BindParameter[Any]] = {} + equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {} has_secondary = self.secondaryjoin is not None @@ -2941,21 +3263,23 @@ class JoinCondition: for l, r in self.local_remote_pairs: equated_columns[l] = r - def col_to_bind(col): + def col_to_bind( + element: ColumnElement[Any], **kw: Any + ) -> Optional[BindParameter[Any]]: if ( - (not reverse_direction and "local" in col._annotations) + (not reverse_direction and "local" in element._annotations) or reverse_direction and ( - (has_secondary and col in lookup) - or (not has_secondary and "remote" in col._annotations) + (has_secondary and element in lookup) + or (not has_secondary and "remote" in element._annotations) ) ): - if col not in binds: - binds[col] = sql.bindparam( - None, None, type_=col.type, unique=True + if element not in binds: + binds[element] = sql.bindparam( + None, None, type_=element.type, unique=True ) - return binds[col] + return binds[element] return None lazywhere = self.primaryjoin @@ -2982,8 +3306,8 @@ class _ColInAnnotations: __slots__ = ("name",) - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __call__(self, c): + def __call__(self, c: ClauseElement) -> bool: return self.name in c._annotations |