summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/relationships.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/relationships.py')
-rw-r--r--lib/sqlalchemy/orm/relationships.py1004
1 files changed, 664 insertions, 340 deletions
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 8273775ae..1186f0f54 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -17,13 +17,23 @@ from __future__ import annotations
import collections
from collections import abc
+import dataclasses
import re
import typing
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -32,14 +42,19 @@ import weakref
from . import attributes
from . import strategy_options
+from ._typing import insp_is_aliased_class
+from ._typing import is_has_collection_adapter
from .base import _is_mapped_class
from .base import class_mapper
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .base import state_str
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
from .interfaces import ONETOMANY
from .interfaces import PropComparator
+from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
from .util import _extract_mapped_subtype
from .util import _orm_annotate
@@ -60,6 +75,7 @@ from ..sql import visitors
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _HasClauseElement
from ..sql.elements import ColumnClause
+from ..sql.elements import ColumnElement
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
@@ -71,15 +87,42 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
from ._typing import _InternalEntityType
+ from ._typing import _O
+ from ._typing import _RegistryType
+ from .clsregistry import _class_resolver
+ from .clsregistry import _ModNS
+ from .dependency import DependencyProcessor
from .mapper import Mapper
+ from .query import Query
+ from .session import Session
+ from .state import InstanceState
+ from .strategies import LazyLoader
from .util import AliasedClass
from .util import AliasedInsp
- from ..sql.elements import ColumnElement
+ from ..sql._typing import _CoreAdapterProto
+ from ..sql._typing import _EquivalentColumnMap
+ from ..sql._typing import _InfoType
+ from ..sql.annotation import _AnnotationDict
+ from ..sql.elements import BinaryExpression
+ from ..sql.elements import BindParameter
+ from ..sql.elements import ClauseElement
+ from ..sql.schema import Table
+ from ..sql.selectable import FromClause
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+
_PT = TypeVar("_PT", bound=Any)
+_PT2 = TypeVar("_PT2", bound=Any)
+
_RelationshipArgumentType = Union[
str,
@@ -111,7 +154,10 @@ _RelationshipJoinConditionArgument = Union[
str, _ColumnExpressionArgument[bool]
]
_ORMOrderByArgument = Union[
- Literal[False], str, _ColumnExpressionArgument[Any]
+ Literal[False],
+ str,
+ _ColumnExpressionArgument[Any],
+ Iterable[Union[str, _ColumnExpressionArgument[Any]]],
]
_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
_ORMColCollectionArgument = Union[
@@ -120,7 +166,19 @@ _ORMColCollectionArgument = Union[
]
-def remote(expr):
+_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any])
+
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
+_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+
+def remote(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'remote' annotation.
@@ -134,12 +192,12 @@ def remote(expr):
:func:`.foreign`
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
)
-def foreign(expr):
+def foreign(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'foreign' annotation.
@@ -154,11 +212,71 @@ def foreign(expr):
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
)
+@dataclasses.dataclass
+class _RelationshipArg(Generic[_T1, _T2]):
+ """stores a user-defined parameter value that must be resolved and
+ parsed later at mapper configuration time.
+
+ """
+
+ __slots__ = "name", "argument", "resolved"
+ name: str
+ argument: _T1
+ resolved: Optional[_T2]
+
+ def _is_populated(self) -> bool:
+ return self.argument is not None
+
+ def _resolve_against_registry(
+ self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+ ) -> None:
+ attr_value = self.argument
+
+ if isinstance(attr_value, str):
+ self.resolved = clsregistry_resolver(
+ attr_value, self.name == "secondary"
+ )()
+ elif callable(attr_value) and not _is_mapped_class(attr_value):
+ self.resolved = attr_value()
+ else:
+ self.resolved = attr_value
+
+
+class _RelationshipArgs(NamedTuple):
+ """stores user-passed parameters that are resolved at mapper configuration
+ time.
+
+ """
+
+ secondary: _RelationshipArg[
+ Optional[Union[FromClause, str]],
+ Optional[FromClause],
+ ]
+ primaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ secondaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ order_by: _RelationshipArg[
+ _ORMOrderByArgument,
+ Union[Literal[None, False], Tuple[ColumnElement[Any], ...]],
+ ]
+ foreign_keys: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+ remote_side: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+
+
@log.class_logger
class Relationship(
_IntrospectsAnnotations, StrategizedProperty[_T], log.Identified
@@ -184,6 +302,10 @@ class Relationship(
_links_to_entity = True
_is_relationship = True
+ _overlaps: Sequence[str]
+
+ _lazy_strategy: LazyLoader
+
_persistence_only = dict(
passive_deletes=False,
passive_updates=True,
@@ -192,56 +314,87 @@ class Relationship(
cascade_backrefs=False,
)
- _dependency_processor = None
+ _dependency_processor: Optional[DependencyProcessor] = None
+
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ _join_condition: JoinCondition
+ order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]]
+
+ _user_defined_foreign_keys: Set[ColumnElement[Any]]
+ _calculated_foreign_keys: Set[ColumnElement[Any]]
+
+ remote_side: Set[ColumnElement[Any]]
+ local_columns: Set[ColumnElement[Any]]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: Optional[_ColumnPairs]
+
+ local_remote_pairs: Optional[_ColumnPairs]
+
+ direction: RelationshipDirection
+
+ _init_args: _RelationshipArgs
def __init__(
self,
argument: Optional[_RelationshipArgumentType[_T]] = None,
- secondary=None,
+ secondary: Optional[Union[FromClause, str]] = None,
*,
- uselist=None,
- collection_class=None,
- primaryjoin=None,
- secondaryjoin=None,
- back_populates=None,
- order_by=False,
- backref=None,
- cascade_backrefs=False,
- overlaps=None,
- post_update=False,
- cascade="save-update, merge",
- viewonly=False,
+ uselist: Optional[bool] = None,
+ collection_class: Optional[
+ Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+ ] = None,
+ primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ back_populates: Optional[str] = None,
+ order_by: _ORMOrderByArgument = False,
+ backref: Optional[_ORMBackrefArgument] = None,
+ overlaps: Optional[str] = None,
+ post_update: bool = False,
+ cascade: str = "save-update, merge",
+ viewonly: bool = False,
lazy: _LazyLoadArgumentType = "select",
- passive_deletes=False,
- passive_updates=True,
- active_history=False,
- enable_typechecks=True,
- foreign_keys=None,
- remote_side=None,
- join_depth=None,
- comparator_factory=None,
- single_parent=False,
- innerjoin=False,
- distinct_target_key=None,
- load_on_pending=False,
- query_class=None,
- info=None,
- omit_join=None,
- sync_backref=None,
- doc=None,
- bake_queries=True,
- _local_remote_pairs=None,
- _legacy_inactive_history_style=False,
+ passive_deletes: Union[Literal["all"], bool] = False,
+ passive_updates: bool = True,
+ active_history: bool = False,
+ enable_typechecks: bool = True,
+ foreign_keys: Optional[_ORMColCollectionArgument] = None,
+ remote_side: Optional[_ORMColCollectionArgument] = None,
+ join_depth: Optional[int] = None,
+ comparator_factory: Optional[
+ Type[Relationship.Comparator[Any]]
+ ] = None,
+ single_parent: bool = False,
+ innerjoin: bool = False,
+ distinct_target_key: Optional[bool] = None,
+ load_on_pending: bool = False,
+ query_class: Optional[Type[Query[Any]]] = None,
+ info: Optional[_InfoType] = None,
+ omit_join: Literal[None, False] = None,
+ sync_backref: Optional[bool] = None,
+ doc: Optional[str] = None,
+ bake_queries: Literal[True] = True,
+ cascade_backrefs: Literal[False] = False,
+ _local_remote_pairs: Optional[_ColumnPairs] = None,
+ _legacy_inactive_history_style: bool = False,
):
super(Relationship, self).__init__()
self.uselist = uselist
self.argument = argument
- self.secondary = secondary
- self.primaryjoin = primaryjoin
- self.secondaryjoin = secondaryjoin
+
+ self._init_args = _RelationshipArgs(
+ _RelationshipArg("secondary", secondary, None),
+ _RelationshipArg("primaryjoin", primaryjoin, None),
+ _RelationshipArg("secondaryjoin", secondaryjoin, None),
+ _RelationshipArg("order_by", order_by, None),
+ _RelationshipArg("foreign_keys", foreign_keys, None),
+ _RelationshipArg("remote_side", remote_side, None),
+ )
+
self.post_update = post_update
- self.direction = None
self.viewonly = viewonly
if viewonly:
self._warn_for_persistence_only_flags(
@@ -258,7 +411,6 @@ class Relationship(
self.sync_backref = sync_backref
self.lazy = lazy
self.single_parent = single_parent
- self._user_defined_foreign_keys = foreign_keys
self.collection_class = collection_class
self.passive_deletes = passive_deletes
@@ -269,7 +421,6 @@ class Relationship(
)
self.passive_updates = passive_updates
- self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
self.query_class = query_class
self.innerjoin = innerjoin
@@ -292,23 +443,22 @@ class Relationship(
self.local_remote_pairs = _local_remote_pairs
self.load_on_pending = load_on_pending
self.comparator_factory = comparator_factory or Relationship.Comparator
- self.comparator = self.comparator_factory(self, None)
util.set_creation_order(self)
if info is not None:
- self.info = info
+ self.info.update(info)
self.strategy_key = (("lazy", self.lazy),)
- self._reverse_property = set()
+ self._reverse_property: Set[Relationship[Any]] = set()
+
if overlaps:
- self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+ self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501
else:
self._overlaps = ()
- self.cascade = cascade
-
- self.order_by = order_by
+ # mypy ignoring the @property setter
+ self.cascade = cascade # type: ignore
self.back_populates = back_populates
@@ -322,7 +472,7 @@ class Relationship(
else:
self.backref = backref
- def _warn_for_persistence_only_flags(self, **kw):
+ def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
for k, v in kw.items():
if v != self._persistence_only[k]:
# we are warning here rather than warn deprecated as this is a
@@ -340,7 +490,7 @@ class Relationship(
"in a future release." % (k,)
)
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
attributes.register_descriptor(
mapper.class_,
self.key,
@@ -378,13 +528,16 @@ class Relationship(
"_extra_criteria",
)
+ prop: RODescriptorReference[Relationship[_PT]]
+ _of_type: Optional[_EntityType[_PT]]
+
def __init__(
self,
- prop,
- parentmapper,
- adapt_to_entity=None,
- of_type=None,
- extra_criteria=(),
+ prop: Relationship[_PT],
+ parentmapper: _InternalEntityType[Any],
+ adapt_to_entity: Optional[AliasedInsp[Any]] = None,
+ of_type: Optional[_EntityType[_PT]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
"""Construction of :class:`.Relationship.Comparator`
is internal to the ORM's attribute mechanics.
@@ -399,15 +552,17 @@ class Relationship(
self._of_type = None
self._extra_criteria = extra_criteria
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self, adapt_to_entity: AliasedInsp[Any]
+ ) -> Relationship.Comparator[Any]:
return self.__class__(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=adapt_to_entity,
of_type=self._of_type,
)
- entity: _InternalEntityType
+ entity: _InternalEntityType[_PT]
"""The target entity referred to by this
:class:`.Relationship.Comparator`.
@@ -419,7 +574,7 @@ class Relationship(
"""
- mapper: Mapper[Any]
+ mapper: Mapper[_PT]
"""The target :class:`_orm.Mapper` referred to by this
:class:`.Relationship.Comparator`.
@@ -428,22 +583,22 @@ class Relationship(
"""
- def _memoized_attr_entity(self) -> _InternalEntityType:
+ def _memoized_attr_entity(self) -> _InternalEntityType[_PT]:
if self._of_type:
- return inspect(self._of_type)
+ return inspect(self._of_type) # type: ignore
else:
return self.prop.entity
- def _memoized_attr_mapper(self) -> Mapper[Any]:
+ def _memoized_attr_mapper(self) -> Mapper[_PT]:
return self.entity.mapper
- def _source_selectable(self):
+ def _source_selectable(self) -> FromClause:
if self._adapt_to_entity:
return self._adapt_to_entity.selectable
else:
return self.property.parent._with_polymorphic_selectable
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[bool]:
adapt_from = self._source_selectable()
if self._of_type:
of_type_entity = inspect(self._of_type)
@@ -457,7 +612,7 @@ class Relationship(
dest,
secondary,
target_adapter,
- ) = self.property._create_joins(
+ ) = self.prop._create_joins(
source_selectable=adapt_from,
source_polymorphic=True,
of_type_entity=of_type_entity,
@@ -469,7 +624,7 @@ class Relationship(
else:
return pj
- def of_type(self, cls):
+ def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]:
r"""Redefine this object in terms of a polymorphic subclass.
See :meth:`.PropComparator.of_type` for an example.
@@ -477,16 +632,16 @@ class Relationship(
"""
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
- of_type=cls,
+ of_type=class_,
extra_criteria=self._extra_criteria,
)
def and_(
self, *criteria: _ColumnExpressionArgument[bool]
- ) -> PropComparator[bool]:
+ ) -> PropComparator[Any]:
"""Add AND criteria.
See :meth:`.PropComparator.and_` for an example.
@@ -500,14 +655,14 @@ class Relationship(
)
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=self._of_type,
extra_criteria=self._extra_criteria + exprs,
)
- def in_(self, other):
+ def in_(self, other: Any) -> NoReturn:
"""Produce an IN clause - this is not implemented
for :func:`_orm.relationship`-based attributes at this time.
@@ -522,7 +677,7 @@ class Relationship(
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``==`` operator.
In a many-to-one context, such as::
@@ -559,7 +714,7 @@ class Relationship(
or many-to-many context produce a NOT EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction in [ONETOMANY, MANYTOMANY]:
return ~self._criterion_exists()
else:
@@ -585,8 +740,18 @@ class Relationship(
criterion: Optional[_ColumnExpressionArgument[bool]] = None,
**kwargs: Any,
) -> Exists:
+
+ where_criteria = (
+ coercions.expect(roles.WhereHavingRole, criterion)
+ if criterion is not None
+ else None
+ )
+
if getattr(self, "_of_type", None):
- info = inspect(self._of_type)
+ info: Optional[_InternalEntityType[Any]] = inspect(
+ self._of_type
+ )
+ assert info is not None
target_mapper, to_selectable, is_aliased_class = (
info.mapper,
info.selectable,
@@ -597,10 +762,10 @@ class Relationship(
single_crit = target_mapper._single_table_criterion
if single_crit is not None:
- if criterion is not None:
- criterion = single_crit & criterion
+ if where_criteria is not None:
+ where_criteria = single_crit & where_criteria
else:
- criterion = single_crit
+ where_criteria = single_crit
else:
is_aliased_class = False
to_selectable = None
@@ -624,10 +789,10 @@ class Relationship(
for k in kwargs:
crit = getattr(self.property.mapper.class_, k) == kwargs[k]
- if criterion is None:
- criterion = crit
+ if where_criteria is None:
+ where_criteria = crit
else:
- criterion = criterion & crit
+ where_criteria = where_criteria & crit
# annotate the *local* side of the join condition, in the case
# of pj + sj this is the full primaryjoin, in the case of just
@@ -638,24 +803,24 @@ class Relationship(
j = _orm_annotate(pj, exclude=self.property.remote_side)
if (
- criterion is not None
+ where_criteria is not None
and target_adapter
and not is_aliased_class
):
# limit this adapter to annotated only?
- criterion = target_adapter.traverse(criterion)
+ where_criteria = target_adapter.traverse(where_criteria)
# only have the "joined left side" of what we
# return be subject to Query adaption. The right
# side of it is used for an exists() subquery and
# should not correlate or otherwise reach out
# to anything in the enclosing query.
- if criterion is not None:
- criterion = criterion._annotate(
+ if where_criteria is not None:
+ where_criteria = where_criteria._annotate(
{"no_replacement_traverse": True}
)
- crit = j & sql.True_._ifnone(criterion)
+ crit = j & sql.True_._ifnone(where_criteria)
if secondary is not None:
ex = (
@@ -673,7 +838,11 @@ class Relationship(
)
return ex
- def any(self, criterion=None, **kwargs):
+ def any(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a collection against
particular criterion, using EXISTS.
@@ -722,7 +891,11 @@ class Relationship(
return self._criterion_exists(criterion, **kwargs)
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a scalar reference against
particular criterion, using EXISTS.
@@ -756,7 +929,9 @@ class Relationship(
)
return self._criterion_exists(criterion, **kwargs)
- def contains(self, other, **kwargs):
+ def contains(
+ self, other: _ColumnExpressionArgument[Any], **kwargs: Any
+ ) -> ColumnElement[bool]:
"""Return a simple expression that tests a collection for
containment of a particular item.
@@ -815,38 +990,45 @@ class Relationship(
kwargs may be ignored by this operator but are required for API
conformance.
"""
- if not self.property.uselist:
+ if not self.prop.uselist:
raise sa_exc.InvalidRequestError(
"'contains' not implemented for scalar "
"attributes. Use =="
)
- clause = self.property._optimized_compare(
+
+ clause = self.prop._optimized_compare(
other, adapt_source=self.adapter
)
- if self.property.secondaryjoin is not None:
+ if self.prop.secondaryjoin is not None:
clause.negation_clause = self.__negated_contains_or_equals(
other
)
return clause
- def __negated_contains_or_equals(self, other):
- if self.property.direction == MANYTOONE:
+ def __negated_contains_or_equals(
+ self, other: Any
+ ) -> ColumnElement[bool]:
+ if self.prop.direction == MANYTOONE:
state = attributes.instance_state(other)
- def state_bindparam(local_col, state, remote_col):
+ def state_bindparam(
+ local_col: ColumnElement[Any],
+ state: InstanceState[Any],
+ remote_col: ColumnElement[Any],
+ ) -> BindParameter[Any]:
dict_ = state.dict
return sql.bindparam(
local_col.key,
type_=local_col.type,
unique=True,
- callable_=self.property._get_attr_w_warn_on_none(
- self.property.mapper, state, dict_, remote_col
+ callable_=self.prop._get_attr_w_warn_on_none(
+ self.prop.mapper, state, dict_, remote_col
),
)
- def adapt(col):
+ def adapt(col: _CE) -> _CE:
if self.adapter:
return self.adapter(col)
else:
@@ -876,7 +1058,7 @@ class Relationship(
return ~self._criterion_exists(criterion)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``!=`` operator.
In a many-to-one context, such as::
@@ -915,7 +1097,7 @@ class Relationship(
or many-to-many context produce an EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction == MANYTOONE:
return _orm_annotate(
~self.property._optimized_compare(
@@ -934,12 +1116,10 @@ class Relationship(
else:
return _orm_annotate(self.__negated_contains_or_equals(other))
- def _memoized_attr_property(self):
+ def _memoized_attr_property(self) -> Relationship[_PT]:
self.prop.parent._check_configure()
return self.prop
- comparator: Comparator[_T]
-
def _with_parent(
self,
instance: object,
@@ -947,10 +1127,11 @@ class Relationship(
from_entity: Optional[_EntityType[Any]] = None,
) -> ColumnElement[bool]:
assert instance is not None
- adapt_source = None
+ adapt_source: Optional[_CoreAdapterProto] = None
if from_entity is not None:
- insp = inspect(from_entity)
- if insp.is_aliased_class:
+ insp: Optional[_InternalEntityType[Any]] = inspect(from_entity)
+ assert insp is not None
+ if insp_is_aliased_class(insp):
adapt_source = insp._adapter.adapt_clause
return self._optimized_compare(
instance,
@@ -961,11 +1142,11 @@ class Relationship(
def _optimized_compare(
self,
- state,
- value_is_parent=False,
- adapt_source=None,
- alias_secondary=True,
- ):
+ state: Any,
+ value_is_parent: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ alias_secondary: bool = True,
+ ) -> ColumnElement[bool]:
if state is not None:
try:
state = inspect(state)
@@ -1005,7 +1186,7 @@ class Relationship(
dict_ = attributes.instance_dict(state.obj())
- def visit_bindparam(bindparam):
+ def visit_bindparam(bindparam: BindParameter[Any]) -> None:
if bindparam._identifying_key in bind_to_col:
bindparam.callable = self._get_attr_w_warn_on_none(
mapper,
@@ -1027,7 +1208,13 @@ class Relationship(
criterion = adapt_source(criterion)
return criterion
- def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+ def _get_attr_w_warn_on_none(
+ self,
+ mapper: Mapper[Any],
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ column: ColumnElement[Any],
+ ) -> Callable[[], Any]:
"""Create the callable that is used in a many-to-one expression.
E.g.::
@@ -1077,9 +1264,14 @@ class Relationship(
# this feature was added explicitly for use in this method.
state._track_last_known_value(prop.key)
- def _go():
- last_known = to_return = state._last_known_values[prop.key]
- existing_is_available = last_known is not attributes.NO_VALUE
+ lkv_fixed = state._last_known_values
+
+ def _go() -> Any:
+ assert lkv_fixed is not None
+ last_known = to_return = lkv_fixed[prop.key]
+ existing_is_available = (
+ last_known is not LoaderCallableStatus.NO_VALUE
+ )
# we support that the value may have changed. so here we
# try to get the most recent value including re-fetching.
@@ -1089,19 +1281,19 @@ class Relationship(
state,
dict_,
column,
- passive=attributes.PASSIVE_OFF
+ passive=PassiveFlag.PASSIVE_OFF
if state.persistent
- else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK,
)
- if current_value is attributes.NEVER_SET:
+ if current_value is LoaderCallableStatus.NEVER_SET:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
"%s; no value has been set for this column"
% (column, state_str(state))
)
- elif current_value is attributes.PASSIVE_NO_RESULT:
+ elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
@@ -1121,7 +1313,11 @@ class Relationship(
return _go
- def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+ def _lazy_none_clause(
+ self,
+ reverse_direction: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ ) -> ColumnElement[bool]:
if not reverse_direction:
criterion, bind_to_col = (
self._lazy_strategy._lazywhere,
@@ -1139,20 +1335,20 @@ class Relationship(
criterion = adapt_source(criterion)
return criterion
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Dict[Any, object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
if load:
for r in self._reverse_property:
@@ -1167,6 +1363,8 @@ class Relationship(
if self.uselist:
impl = source_state.get_impl(self.key)
+
+ assert is_has_collection_adapter(impl)
instances_iterable = impl.get_collection(source_state, source_dict)
# if this is a CollectionAttributeImpl, then empty should
@@ -1204,9 +1402,9 @@ class Relationship(
for c in dest_list:
coll.append_without_event(c)
else:
- dest_state.get_impl(self.key).set(
- dest_state, dest_dict, dest_list, _adapt=False
- )
+ dest_impl = dest_state.get_impl(self.key)
+ assert is_has_collection_adapter(dest_impl)
+ dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False)
else:
current = source_dict[self.key]
if current is not None:
@@ -1231,8 +1429,12 @@ class Relationship(
)
def _value_as_iterable(
- self, state, dict_, key, passive=attributes.PASSIVE_OFF
- ):
+ self,
+ state: InstanceState[_O],
+ dict_: _InstanceDict,
+ key: str,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> Sequence[Tuple[InstanceState[_O], _O]]:
"""Return a list of tuples (state, obj) for the given
key.
@@ -1241,9 +1443,9 @@ class Relationship(
impl = state.manager[key].impl
x = impl.get(state, dict_, passive=passive)
- if x is attributes.PASSIVE_NO_RESULT or x is None:
+ if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None:
return []
- elif hasattr(impl, "get_collection"):
+ elif is_has_collection_adapter(impl):
return [
(attributes.instance_state(o), o)
for o in impl.get_collection(state, dict_, x, passive=passive)
@@ -1252,19 +1454,23 @@ class Relationship(
return [(attributes.instance_state(x), x)]
def cascade_iterator(
- self, type_, state, dict_, visited_states, halt_on=None
- ):
+ self,
+ type_: str,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ visited_states: Set[InstanceState[Any]],
+ halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+ ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]:
# assert type_ in self._cascade
# only actively lazy load on the 'delete' cascade
if type_ != "delete" or self.passive_deletes:
- passive = attributes.PASSIVE_NO_INITIALIZE
+ passive = PassiveFlag.PASSIVE_NO_INITIALIZE
else:
- passive = attributes.PASSIVE_OFF
+ passive = PassiveFlag.PASSIVE_OFF
if type_ == "save-update":
tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
-
else:
tuples = self._value_as_iterable(
state, dict_, self.key, passive=passive
@@ -1285,6 +1491,7 @@ class Relationship(
# see [ticket:2229]
continue
+ assert instance_state is not None
instance_dict = attributes.instance_dict(c)
if halt_on and halt_on(instance_state):
@@ -1308,14 +1515,16 @@ class Relationship(
yield c, instance_mapper, instance_state, instance_dict
@property
- def _effective_sync_backref(self):
+ def _effective_sync_backref(self) -> bool:
if self.viewonly:
return False
else:
return self.sync_backref is not False
@staticmethod
- def _check_sync_backref(rel_a, rel_b):
+ def _check_sync_backref(
+ rel_a: Relationship[Any], rel_b: Relationship[Any]
+ ) -> None:
if rel_a.viewonly and rel_b.sync_backref:
raise sa_exc.InvalidRequestError(
"Relationship %s cannot specify sync_backref=True since %s "
@@ -1328,7 +1537,7 @@ class Relationship(
):
rel_b.sync_backref = False
- def _add_reverse_property(self, key):
+ def _add_reverse_property(self, key: str) -> None:
other = self.mapper.get_property(key, _configure_mappers=False)
if not isinstance(other, Relationship):
raise sa_exc.InvalidRequestError(
@@ -1361,7 +1570,8 @@ class Relationship(
)
if (
- self.direction in (ONETOMANY, MANYTOONE)
+ other._configure_started
+ and self.direction in (ONETOMANY, MANYTOONE)
and self.direction == other.direction
):
raise sa_exc.ArgumentError(
@@ -1372,7 +1582,7 @@ class Relationship(
)
@util.memoized_property
- def entity(self) -> Union["Mapper", "AliasedInsp"]:
+ def entity(self) -> _InternalEntityType[_T]:
"""Return the target mapped entity, which is an inspect() of the
class or aliased class that is referred towards.
@@ -1388,7 +1598,7 @@ class Relationship(
"""
return self.entity.mapper
- def do_init(self):
+ def do_init(self) -> None:
self._check_conflicts()
self._process_dependent_arguments()
self._setup_entity()
@@ -1399,14 +1609,16 @@ class Relationship(
self._generate_backref()
self._join_condition._warn_for_conflicting_sync_targets()
super(Relationship, self).do_init()
- self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+ self._lazy_strategy = cast(
+ "LazyLoader", self._get_strategy((("lazy", "select"),))
+ )
- def _setup_registry_dependencies(self):
+ def _setup_registry_dependencies(self) -> None:
self.parent.mapper.registry._set_depends_on(
self.entity.mapper.registry
)
- def _process_dependent_arguments(self):
+ def _process_dependent_arguments(self) -> None:
"""Convert incoming configuration arguments to their
proper form.
@@ -1417,78 +1629,80 @@ class Relationship(
# accept callables for other attributes which may require
# deferred initialization. This technique is used
# by declarative "string configs" and some recipes.
+ init_args = self._init_args
+
for attr in (
"order_by",
"primaryjoin",
"secondaryjoin",
"secondary",
- "_user_defined_foreign_keys",
+ "foreign_keys",
"remote_side",
):
- attr_value = getattr(self, attr)
-
- if isinstance(attr_value, str):
- setattr(
- self,
- attr,
- self._clsregistry_resolve_arg(
- attr_value, favor_tables=attr == "secondary"
- )(),
- )
- elif callable(attr_value) and not _is_mapped_class(attr_value):
- setattr(self, attr, attr_value())
+
+ rel_arg = getattr(init_args, attr)
+
+ rel_arg._resolve_against_registry(self._clsregistry_resolvers[1])
# remove "annotations" which are present if mapped class
# descriptors are used to create the join expression.
for attr in "primaryjoin", "secondaryjoin":
- val = getattr(self, attr)
+ rel_arg = getattr(init_args, attr)
+ val = rel_arg.resolved
if val is not None:
- setattr(
- self,
- attr,
- _orm_deannotate(
- coercions.expect(
- roles.ColumnArgumentRole, val, argname=attr
- )
- ),
+ rel_arg.resolved = _orm_deannotate(
+ coercions.expect(
+ roles.ColumnArgumentRole, val, argname=attr
+ )
)
- if self.secondary is not None and _is_mapped_class(self.secondary):
+ secondary = init_args.secondary.resolved
+ if secondary is not None and _is_mapped_class(secondary):
raise sa_exc.ArgumentError(
"secondary argument %s passed to to relationship() %s must "
"be a Table object or other FROM clause; can't send a mapped "
"class directly as rows in 'secondary' are persisted "
"independently of a class that is mapped "
- "to that same table." % (self.secondary, self)
+ "to that same table." % (secondary, self)
)
# ensure expressions in self.order_by, foreign_keys,
# remote_side are all columns, not strings.
- if self.order_by is not False and self.order_by is not None:
+ if (
+ init_args.order_by.resolved is not False
+ and init_args.order_by.resolved is not None
+ ):
self.order_by = tuple(
coercions.expect(
roles.ColumnArgumentRole, x, argname="order_by"
)
- for x in util.to_list(self.order_by)
+ for x in util.to_list(init_args.order_by.resolved)
)
+ else:
+ self.order_by = False
self._user_defined_foreign_keys = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="foreign_keys"
)
- for x in util.to_column_set(self._user_defined_foreign_keys)
+ for x in util.to_column_set(init_args.foreign_keys.resolved)
)
self.remote_side = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="remote_side"
)
- for x in util.to_column_set(self.remote_side)
+ for x in util.to_column_set(init_args.remote_side.resolved)
)
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
argument = _extract_mapped_subtype(
annotation,
cls,
@@ -1502,17 +1716,19 @@ class Relationship(
if hasattr(argument, "__origin__"):
- collection_class = argument.__origin__
+ collection_class = argument.__origin__ # type: ignore
if issubclass(collection_class, abc.Collection):
if self.collection_class is None:
self.collection_class = collection_class
else:
self.uselist = False
- if argument.__args__:
- if issubclass(argument.__origin__, typing.Mapping):
- type_arg = argument.__args__[1]
+ if argument.__args__: # type: ignore
+ if issubclass(
+ argument.__origin__, typing.Mapping # type: ignore
+ ):
+ type_arg = argument.__args__[1] # type: ignore
else:
- type_arg = argument.__args__[0]
+ type_arg = argument.__args__[0] # type: ignore
if hasattr(type_arg, "__forward_arg__"):
str_argument = type_arg.__forward_arg__
argument = str_argument
@@ -1523,12 +1739,12 @@ class Relationship(
f"Generic alias {argument} requires an argument"
)
elif hasattr(argument, "__forward_arg__"):
- argument = argument.__forward_arg__
+ argument = argument.__forward_arg__ # type: ignore
self.argument = argument
@util.preload_module("sqlalchemy.orm.mapper")
- def _setup_entity(self, __argument=None):
+ def _setup_entity(self, __argument: Any = None) -> None:
if "entity" in self.__dict__:
return
@@ -1539,42 +1755,51 @@ class Relationship(
else:
argument = self.argument
+ resolved_argument: _ExternalEntityType[Any]
+
if isinstance(argument, str):
- argument = self._clsregistry_resolve_name(argument)()
+ # we might want to cleanup clsregistry API to make this
+ # more straightforward
+ resolved_argument = cast(
+ "_ExternalEntityType[Any]",
+ self._clsregistry_resolve_name(argument)(),
+ )
elif callable(argument) and not isinstance(
argument, (type, mapperlib.Mapper)
):
- argument = argument()
+ resolved_argument = argument()
else:
- argument = argument
+ resolved_argument = argument
- if isinstance(argument, type):
- entity = class_mapper(argument, configure=False)
+ entity: _InternalEntityType[Any]
+
+ if isinstance(resolved_argument, type):
+ entity = class_mapper(resolved_argument, configure=False)
else:
try:
- entity = inspect(argument)
+ entity = inspect(resolved_argument)
except sa_exc.NoInspectionAvailable:
- entity = None
+ entity = None # type: ignore
if not hasattr(entity, "mapper"):
raise sa_exc.ArgumentError(
"relationship '%s' expects "
"a class or a mapper argument (received: %s)"
- % (self.key, type(argument))
+ % (self.key, type(resolved_argument))
)
self.entity = entity # type: ignore
self.target = self.entity.persist_selectable
- def _setup_join_conditions(self):
+ def _setup_join_conditions(self) -> None:
self._join_condition = jc = JoinCondition(
parent_persist_selectable=self.parent.persist_selectable,
child_persist_selectable=self.entity.persist_selectable,
parent_local_selectable=self.parent.local_table,
child_local_selectable=self.entity.local_table,
- primaryjoin=self.primaryjoin,
- secondary=self.secondary,
- secondaryjoin=self.secondaryjoin,
+ primaryjoin=self._init_args.primaryjoin.resolved,
+ secondary=self._init_args.secondary.resolved,
+ secondaryjoin=self._init_args.secondaryjoin.resolved,
parent_equivalents=self.parent._equivalent_columns,
child_equivalents=self.mapper._equivalent_columns,
consider_as_foreign_keys=self._user_defined_foreign_keys,
@@ -1587,6 +1812,7 @@ class Relationship(
)
self.primaryjoin = jc.primaryjoin
self.secondaryjoin = jc.secondaryjoin
+ self.secondary = jc.secondary
self.direction = jc.direction
self.local_remote_pairs = jc.local_remote_pairs
self.remote_side = jc.remote_columns
@@ -1596,21 +1822,30 @@ class Relationship(
self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
@property
- def _clsregistry_resolve_arg(self):
+ def _clsregistry_resolve_arg(
+ self,
+ ) -> Callable[[str, bool], _class_resolver]:
return self._clsregistry_resolvers[1]
@property
- def _clsregistry_resolve_name(self):
+ def _clsregistry_resolve_name(
+ self,
+ ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]:
return self._clsregistry_resolvers[0]
@util.memoized_property
@util.preload_module("sqlalchemy.orm.clsregistry")
- def _clsregistry_resolvers(self):
+ def _clsregistry_resolvers(
+ self,
+ ) -> Tuple[
+ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+ Callable[[str, bool], _class_resolver],
+ ]:
_resolver = util.preloaded.orm_clsregistry._resolver
return _resolver(self.parent.class_, self)
- def _check_conflicts(self):
+ def _check_conflicts(self) -> None:
"""Test that this relationship is legal, warn about
inheritance conflicts."""
if self.parent.non_primary and not class_mapper(
@@ -1637,10 +1872,10 @@ class Relationship(
return self._cascade
@cascade.setter
- def cascade(self, cascade: Union[str, CascadeOptions]):
+ def cascade(self, cascade: Union[str, CascadeOptions]) -> None:
self._set_cascade(cascade)
- def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+ def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None:
cascade = CascadeOptions(cascade_arg)
if self.viewonly:
@@ -1655,7 +1890,7 @@ class Relationship(
if self._dependency_processor:
self._dependency_processor.cascade = cascade
- def _check_cascade_settings(self, cascade):
+ def _check_cascade_settings(self, cascade: CascadeOptions) -> None:
if (
cascade.delete_orphan
and not self.single_parent
@@ -1699,7 +1934,7 @@ class Relationship(
(self.key, self.parent.class_)
)
- def _persists_for(self, mapper):
+ def _persists_for(self, mapper: Mapper[Any]) -> bool:
"""Return True if this property will persist values on behalf
of the given mapper.
@@ -1710,16 +1945,15 @@ class Relationship(
and mapper.relationships[self.key] is self
)
- def _columns_are_mapped(self, *cols):
+ def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool:
"""Return True if all columns in the given collection are
mapped by the tables referenced by this :class:`.Relationship`.
"""
+
+ secondary = self._init_args.secondary.resolved
for c in cols:
- if (
- self.secondary is not None
- and self.secondary.c.contains_column(c)
- ):
+ if secondary is not None and secondary.c.contains_column(c):
continue
if not self.parent.persist_selectable.c.contains_column(
c
@@ -1727,13 +1961,14 @@ class Relationship(
return False
return True
- def _generate_backref(self):
+ def _generate_backref(self) -> None:
"""Interpret the 'backref' instruction to create a
:func:`_orm.relationship` complementary to this one."""
if self.parent.non_primary:
return
if self.backref is not None and not self.back_populates:
+ kwargs: Dict[str, Any]
if isinstance(self.backref, str):
backref_key, kwargs = self.backref, {}
else:
@@ -1805,7 +2040,7 @@ class Relationship(
self._add_reverse_property(self.back_populates)
@util.preload_module("sqlalchemy.orm.dependency")
- def _post_init(self):
+ def _post_init(self) -> None:
dependency = util.preloaded.orm_dependency
if self.uselist is None:
@@ -1816,7 +2051,7 @@ class Relationship(
)(self)
@util.memoized_property
- def _use_get(self):
+ def _use_get(self) -> bool:
"""memoize the 'use_get' attribute of this RelationshipLoader's
lazyloader."""
@@ -1824,18 +2059,25 @@ class Relationship(
return strategy.use_get
@util.memoized_property
- def _is_self_referential(self):
+ def _is_self_referential(self) -> bool:
return self.mapper.common_parent(self.parent)
def _create_joins(
self,
- source_polymorphic=False,
- source_selectable=None,
- dest_selectable=None,
- of_type_entity=None,
- alias_secondary=False,
- extra_criteria=(),
- ):
+ source_polymorphic: bool = False,
+ source_selectable: Optional[FromClause] = None,
+ dest_selectable: Optional[FromClause] = None,
+ of_type_entity: Optional[_InternalEntityType[Any]] = None,
+ alias_secondary: bool = False,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ FromClause,
+ FromClause,
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ ]:
aliased = False
@@ -1905,38 +2147,56 @@ class Relationship(
)
-def _annotate_columns(element, annotations):
- def clone(elem):
+def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE:
+ def clone(elem: _CE) -> _CE:
if isinstance(elem, expression.ColumnClause):
- elem = elem._annotate(annotations.copy())
+ elem = elem._annotate(annotations.copy()) # type: ignore
elem._copy_internals(clone=clone)
return elem
if element is not None:
element = clone(element)
- clone = None # remove gc cycles
+ clone = None # type: ignore # remove gc cycles
return element
class JoinCondition:
+
+ primaryjoin_initial: Optional[ColumnElement[bool]]
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ prop: Relationship[Any]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: _ColumnPairs
+ direction: RelationshipDirection
+
+ parent_persist_selectable: FromClause
+ child_persist_selectable: FromClause
+ parent_local_selectable: FromClause
+ child_local_selectable: FromClause
+
+ _local_remote_pairs: Optional[_ColumnPairs]
+
def __init__(
self,
- parent_persist_selectable,
- child_persist_selectable,
- parent_local_selectable,
- child_local_selectable,
- primaryjoin=None,
- secondary=None,
- secondaryjoin=None,
- parent_equivalents=None,
- child_equivalents=None,
- consider_as_foreign_keys=None,
- local_remote_pairs=None,
- remote_side=None,
- self_referential=False,
- prop=None,
- support_sync=True,
- can_be_synced_fn=lambda *c: True,
+ parent_persist_selectable: FromClause,
+ child_persist_selectable: FromClause,
+ parent_local_selectable: FromClause,
+ child_local_selectable: FromClause,
+ primaryjoin: Optional[ColumnElement[bool]] = None,
+ secondary: Optional[FromClause] = None,
+ secondaryjoin: Optional[ColumnElement[bool]] = None,
+ parent_equivalents: Optional[_EquivalentColumnMap] = None,
+ child_equivalents: Optional[_EquivalentColumnMap] = None,
+ consider_as_foreign_keys: Any = None,
+ local_remote_pairs: Optional[_ColumnPairs] = None,
+ remote_side: Any = None,
+ self_referential: Any = False,
+ prop: Optional[Relationship[Any]] = None,
+ support_sync: bool = True,
+ can_be_synced_fn: Callable[..., bool] = lambda *c: True,
):
self.parent_persist_selectable = parent_persist_selectable
self.parent_local_selectable = parent_local_selectable
@@ -1944,7 +2204,7 @@ class JoinCondition:
self.child_local_selectable = child_local_selectable
self.parent_equivalents = parent_equivalents
self.child_equivalents = child_equivalents
- self.primaryjoin = primaryjoin
+ self.primaryjoin_initial = primaryjoin
self.secondaryjoin = secondaryjoin
self.secondary = secondary
self.consider_as_foreign_keys = consider_as_foreign_keys
@@ -1954,7 +2214,10 @@ class JoinCondition:
self.self_referential = self_referential
self.support_sync = support_sync
self.can_be_synced_fn = can_be_synced_fn
+
self._determine_joins()
+ assert self.primaryjoin is not None
+
self._sanitize_joins()
self._annotate_fks()
self._annotate_remote()
@@ -1968,7 +2231,7 @@ class JoinCondition:
self._check_remote_side()
self._log_joins()
- def _log_joins(self):
+ def _log_joins(self) -> None:
if self.prop is None:
return
log = self.prop.logger
@@ -2008,7 +2271,7 @@ class JoinCondition:
)
log.info("%s relationship direction %s", self.prop, self.direction)
- def _sanitize_joins(self):
+ def _sanitize_joins(self) -> None:
"""remove the parententity annotation from our join conditions which
can leak in here based on some declarative patterns and maybe others.
@@ -2026,7 +2289,7 @@ class JoinCondition:
self.secondaryjoin, values=("parententity", "proxy_key")
)
- def _determine_joins(self):
+ def _determine_joins(self) -> None:
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
if not passed to the constructor already.
@@ -2056,21 +2319,25 @@ class JoinCondition:
a_subset=self.child_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.secondary,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
else:
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.child_persist_selectable,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
except sa_exc.NoForeignKeysError as nfe:
if self.secondary is not None:
raise sa_exc.NoForeignKeysError(
@@ -2118,15 +2385,16 @@ class JoinCondition:
) from afe
@property
- def primaryjoin_minus_local(self):
+ def primaryjoin_minus_local(self) -> ColumnElement[bool]:
return _deep_deannotate(self.primaryjoin, values=("local", "remote"))
@property
- def secondaryjoin_minus_local(self):
+ def secondaryjoin_minus_local(self) -> ColumnElement[bool]:
+ assert self.secondaryjoin is not None
return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
@util.memoized_property
- def primaryjoin_reverse_remote(self):
+ def primaryjoin_reverse_remote(self) -> ColumnElement[bool]:
"""Return the primaryjoin condition suitable for the
"reverse" direction.
@@ -2138,7 +2406,7 @@ class JoinCondition:
"""
if self._has_remote_annotations:
- def replace(element):
+ def replace(element: _CE, **kw: Any) -> Optional[_CE]:
if "remote" in element._annotations:
v = dict(element._annotations)
del v["remote"]
@@ -2150,6 +2418,8 @@ class JoinCondition:
v["remote"] = True
return element._with_annotations(v)
+ return None
+
return visitors.replacement_traverse(self.primaryjoin, {}, replace)
else:
if self._has_foreign_annotations:
@@ -2160,7 +2430,7 @@ class JoinCondition:
else:
return _deep_deannotate(self.primaryjoin)
- def _has_annotation(self, clause, annotation):
+ def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool:
for col in visitors.iterate(clause, {}):
if annotation in col._annotations:
return True
@@ -2168,14 +2438,14 @@ class JoinCondition:
return False
@util.memoized_property
- def _has_foreign_annotations(self):
+ def _has_foreign_annotations(self) -> bool:
return self._has_annotation(self.primaryjoin, "foreign")
@util.memoized_property
- def _has_remote_annotations(self):
+ def _has_remote_annotations(self) -> bool:
return self._has_annotation(self.primaryjoin, "remote")
- def _annotate_fks(self):
+ def _annotate_fks(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'foreign' annotations marking columns
considered as foreign.
@@ -2189,10 +2459,11 @@ class JoinCondition:
else:
self._annotate_present_fks()
- def _annotate_from_fk_list(self):
- def check_fk(col):
- if col in self.consider_as_foreign_keys:
- return col._annotate({"foreign": True})
+ def _annotate_from_fk_list(self) -> None:
+ def check_fk(element: _CE, **kw: Any) -> Optional[_CE]:
+ if element in self.consider_as_foreign_keys:
+ return element._annotate({"foreign": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, check_fk
@@ -2202,13 +2473,15 @@ class JoinCondition:
self.secondaryjoin, {}, check_fk
)
- def _annotate_present_fks(self):
+ def _annotate_present_fks(self) -> None:
if self.secondary is not None:
secondarycols = util.column_set(self.secondary.c)
else:
secondarycols = set()
- def is_foreign(a, b):
+ def is_foreign(
+ a: ColumnElement[Any], b: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
if isinstance(a, schema.Column) and isinstance(b, schema.Column):
if a.references(b):
return a
@@ -2221,7 +2494,9 @@ class JoinCondition:
elif b in secondarycols and a not in secondarycols:
return b
- def visit_binary(binary):
+ return None
+
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
if not isinstance(
binary.left, sql.ColumnElement
) or not isinstance(binary.right, sql.ColumnElement):
@@ -2248,16 +2523,17 @@ class JoinCondition:
self.secondaryjoin, {}, {"binary": visit_binary}
)
- def _refers_to_parent_table(self):
+ def _refers_to_parent_table(self) -> bool:
"""Return True if the join condition contains column
comparisons where both columns are in both tables.
"""
pt = self.parent_persist_selectable
mt = self.child_persist_selectable
- result = [False]
+ result = False
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
+ nonlocal result
c, f = binary.left, binary.right
if (
isinstance(c, expression.ColumnClause)
@@ -2267,19 +2543,19 @@ class JoinCondition:
and mt.is_derived_from(c.table)
and mt.is_derived_from(f.table)
):
- result[0] = True
+ result = True
visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
- return result[0]
+ return result
- def _tables_overlap(self):
+ def _tables_overlap(self) -> bool:
"""Return True if parent/child tables have some overlap."""
return selectables_overlap(
self.parent_persist_selectable, self.child_persist_selectable
)
- def _annotate_remote(self):
+ def _annotate_remote(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'remote' annotations marking columns
considered as part of the 'remote' side.
@@ -2301,30 +2577,38 @@ class JoinCondition:
else:
self._annotate_remote_distinct_selectables()
- def _annotate_remote_secondary(self):
+ def _annotate_remote_secondary(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when 'secondary' is present.
"""
- def repl(element):
- if self.secondary.c.contains_column(element):
+ assert self.secondary is not None
+ fixed_secondary = self.secondary
+
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
+ if fixed_secondary.c.contains_column(element):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
+
+ assert self.secondaryjoin is not None
self.secondaryjoin = visitors.replacement_traverse(
self.secondaryjoin, {}, repl
)
- def _annotate_selfref(self, fn, remote_side_given):
+ def _annotate_selfref(
+ self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool
+ ) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the relationship is detected as self-referential.
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
equated = binary.left.compare(binary.right)
if isinstance(binary.left, expression.ColumnClause) and isinstance(
binary.right, expression.ColumnClause
@@ -2341,7 +2625,7 @@ class JoinCondition:
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_from_args(self):
+ def _annotate_remote_from_args(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the 'remote_side' or '_local_remote_pairs'
arguments are used.
@@ -2363,17 +2647,18 @@ class JoinCondition:
self._annotate_selfref(lambda col: col in remote_side, True)
else:
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
# use set() to avoid generating ``__eq__()`` expressions
# against each element
if element in set(remote_side):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
- def _annotate_remote_with_overlap(self):
+ def _annotate_remote_with_overlap(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables have some set of
tables in common, though is not a fully self-referential
@@ -2381,7 +2666,7 @@ class JoinCondition:
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
binary.left, binary.right = proc_left_right(
binary.left, binary.right
)
@@ -2393,7 +2678,9 @@ class JoinCondition:
self.prop is not None and self.prop.mapper is not self.prop.parent
)
- def proc_left_right(left, right):
+ def proc_left_right(
+ left: ColumnElement[Any], right: ColumnElement[Any]
+ ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]:
if isinstance(left, expression.ColumnClause) and isinstance(
right, expression.ColumnClause
):
@@ -2420,32 +2707,33 @@ class JoinCondition:
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_distinct_selectables(self):
+ def _annotate_remote_distinct_selectables(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables are entirely
separate.
"""
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
if self.child_persist_selectable.c.contains_column(element) and (
not self.parent_local_selectable.c.contains_column(element)
or self.child_local_selectable.c.contains_column(element)
):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
- def _warn_non_column_elements(self):
+ def _warn_non_column_elements(self) -> None:
util.warn(
"Non-simple column elements in primary "
"join condition for property %s - consider using "
"remote() annotations to mark the remote side." % self.prop
)
- def _annotate_local(self):
+ def _annotate_local(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'local' annotations.
@@ -2466,29 +2754,31 @@ class JoinCondition:
else:
local_side = util.column_set(self.parent_persist_selectable.c)
- def locals_(elem):
- if "remote" not in elem._annotations and elem in local_side:
- return elem._annotate({"local": True})
+ def locals_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" not in element._annotations and element in local_side:
+ return element._annotate({"local": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
- def _annotate_parentmapper(self):
+ def _annotate_parentmapper(self) -> None:
if self.prop is None:
return
- def parentmappers_(elem):
- if "remote" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.mapper})
- elif "local" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.parent})
+ def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.mapper})
+ elif "local" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.parent})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, parentmappers_
)
- def _check_remote_side(self):
+ def _check_remote_side(self) -> None:
if not self.local_remote_pairs:
raise sa_exc.ArgumentError(
"Relationship %s could "
@@ -2501,7 +2791,9 @@ class JoinCondition:
"the relationship." % (self.prop,)
)
- def _check_foreign_cols(self, join_condition, primary):
+ def _check_foreign_cols(
+ self, join_condition: ColumnElement[bool], primary: bool
+ ) -> None:
"""Check the foreign key columns collected and emit error
messages."""
@@ -2567,7 +2859,7 @@ class JoinCondition:
)
raise sa_exc.ArgumentError(err)
- def _determine_direction(self):
+ def _determine_direction(self) -> None:
"""Determine if this relationship is one to many, many to one,
many to many.
@@ -2651,7 +2943,9 @@ class JoinCondition:
"nor the child's mapped tables" % self.prop
)
- def _deannotate_pairs(self, collection):
+ def _deannotate_pairs(
+ self, collection: _ColumnPairIterable
+ ) -> _MutableColumnPairs:
"""provide deannotation for the various lists of
pairs, so that using them in hashes doesn't incur
high-overhead __eq__() comparisons against
@@ -2660,13 +2954,22 @@ class JoinCondition:
"""
return [(x._deannotate(), y._deannotate()) for x, y in collection]
- def _setup_pairs(self):
- sync_pairs = []
- lrp = util.OrderedSet([])
- secondary_sync_pairs = []
-
- def go(joincond, collection):
- def visit_binary(binary, left, right):
+ def _setup_pairs(self) -> None:
+ sync_pairs: _MutableColumnPairs = []
+ lrp: util.OrderedSet[
+ Tuple[ColumnElement[Any], ColumnElement[Any]]
+ ] = util.OrderedSet([])
+ secondary_sync_pairs: _MutableColumnPairs = []
+
+ def go(
+ joincond: ColumnElement[bool],
+ collection: _MutableColumnPairs,
+ ) -> None:
+ def visit_binary(
+ binary: BinaryExpression[Any],
+ left: ColumnElement[Any],
+ right: ColumnElement[Any],
+ ) -> None:
if (
"remote" in right._annotations
and "remote" not in left._annotations
@@ -2703,9 +3006,12 @@ class JoinCondition:
secondary_sync_pairs
)
- _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+ _track_overlapping_sync_targets: weakref.WeakKeyDictionary[
+ ColumnElement[Any],
+ weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]],
+ ] = weakref.WeakKeyDictionary()
- def _warn_for_conflicting_sync_targets(self):
+ def _warn_for_conflicting_sync_targets(self) -> None:
if not self.support_sync:
return
@@ -2793,18 +3099,20 @@ class JoinCondition:
self._track_overlapping_sync_targets[to_][self.prop] = from_
@util.memoized_property
- def remote_columns(self):
+ def remote_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("remote")
@util.memoized_property
- def local_columns(self):
+ def local_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("local")
@util.memoized_property
- def foreign_key_columns(self):
+ def foreign_key_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("foreign")
- def _gather_join_annotations(self, annotation):
+ def _gather_join_annotations(
+ self, annotation: str
+ ) -> Set[ColumnElement[Any]]:
s = set(
self._gather_columns_with_annotation(self.primaryjoin, annotation)
)
@@ -2816,24 +3124,32 @@ class JoinCondition:
)
return {x._deannotate() for x in s}
- def _gather_columns_with_annotation(self, clause, *annotation):
- annotation = set(annotation)
+ def _gather_columns_with_annotation(
+ self, clause: ColumnElement[Any], *annotation: Iterable[str]
+ ) -> Set[ColumnElement[Any]]:
+ annotation_set = set(annotation)
return set(
[
- col
+ cast(ColumnElement[Any], col)
for col in visitors.iterate(clause, {})
- if annotation.issubset(col._annotations)
+ if annotation_set.issubset(col._annotations)
]
)
def join_targets(
self,
- source_selectable,
- dest_selectable,
- aliased,
- single_crit=None,
- extra_criteria=(),
- ):
+ source_selectable: Optional[FromClause],
+ dest_selectable: FromClause,
+ aliased: bool,
+ single_crit: Optional[ColumnElement[bool]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ FromClause,
+ ]:
"""Given a source and destination selectable, create a
join between them.
@@ -2923,9 +3239,15 @@ class JoinCondition:
dest_selectable,
)
- def create_lazy_clause(self, reverse_direction=False):
- binds = util.column_dict()
- equated_columns = util.column_dict()
+ def create_lazy_clause(
+ self, reverse_direction: bool = False
+ ) -> Tuple[
+ ColumnElement[bool],
+ Dict[str, ColumnElement[Any]],
+ Dict[ColumnElement[Any], ColumnElement[Any]],
+ ]:
+ binds: Dict[ColumnElement[Any], BindParameter[Any]] = {}
+ equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {}
has_secondary = self.secondaryjoin is not None
@@ -2941,21 +3263,23 @@ class JoinCondition:
for l, r in self.local_remote_pairs:
equated_columns[l] = r
- def col_to_bind(col):
+ def col_to_bind(
+ element: ColumnElement[Any], **kw: Any
+ ) -> Optional[BindParameter[Any]]:
if (
- (not reverse_direction and "local" in col._annotations)
+ (not reverse_direction and "local" in element._annotations)
or reverse_direction
and (
- (has_secondary and col in lookup)
- or (not has_secondary and "remote" in col._annotations)
+ (has_secondary and element in lookup)
+ or (not has_secondary and "remote" in element._annotations)
)
):
- if col not in binds:
- binds[col] = sql.bindparam(
- None, None, type_=col.type, unique=True
+ if element not in binds:
+ binds[element] = sql.bindparam(
+ None, None, type_=element.type, unique=True
)
- return binds[col]
+ return binds[element]
return None
lazywhere = self.primaryjoin
@@ -2982,8 +3306,8 @@ class _ColInAnnotations:
__slots__ = ("name",)
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
- def __call__(self, c):
+ def __call__(self, c: ClauseElement) -> bool:
return self.name in c._annotations