diff options
Diffstat (limited to 'lib/sqlalchemy/orm/strategy_options.py')
-rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 514 |
1 files changed, 332 insertions, 182 deletions
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 63679dd27..7aed6dd7b 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -3,6 +3,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls """ @@ -12,18 +13,30 @@ from __future__ import annotations import typing from typing import Any +from typing import Callable from typing import cast -from typing import Mapping -from typing import NoReturn +from typing import Dict +from typing import Iterable from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple +from typing import Type +from typing import TypeVar from typing import Union from . import util as orm_util +from ._typing import insp_is_aliased_class +from ._typing import insp_is_attribute +from ._typing import insp_is_mapper +from ._typing import insp_is_mapper_property +from .attributes import QueryableAttribute from .base import InspectionAttr from .interfaces import LoaderOption from .path_registry import _DEFAULT_TOKEN from .path_registry import _WILDCARD_TOKEN +from .path_registry import AbstractEntityRegistry +from .path_registry import path_is_property from .path_registry import PathRegistry from .path_registry import TokenRegistry from .util import _orm_full_deannotate @@ -38,14 +51,37 @@ from ..sql import roles from ..sql import traversals from ..sql import visitors from ..sql.base import _generative +from ..util.typing import Final +from ..util.typing import Literal -_RELATIONSHIP_TOKEN = "relationship" -_COLUMN_TOKEN = "column" +_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship" +_COLUMN_TOKEN: Final[Literal["column"]] = "column" + +_FN = TypeVar("_FN", bound="Callable[..., Any]") if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _InternalEntityType + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .interfaces import _StrategyKey + from .interfaces import MapperProperty from .mapper import Mapper + from .path_registry import _PathRepresentation + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.cache_key import CacheKey + +Self_AbstractLoad = TypeVar("Self_AbstractLoad", bound="_AbstractLoad") + +_AttrType = Union[str, "QueryableAttribute[Any]"] -Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad") +_WildcardKeyType = Literal["relationship", "column"] +_StrategySpec = Dict[str, Any] +_OptsType = Dict[str, Any] +_AttrGroupType = Tuple[_AttrType, ...] class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @@ -54,7 +90,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): _is_strategy_option = True propagate_to_loaders: bool - def contains_eager(self, attr, alias=None, _is_chain=False): + def contains_eager( + self: Self_AbstractLoad, + attr: _AttrType, + alias: Optional[_FromClauseArgument] = None, + _is_chain: bool = False, + ) -> Self_AbstractLoad: r"""Indicate that the given attribute should be eagerly loaded from columns stated manually in the query. @@ -94,9 +135,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ if alias is not None: if not isinstance(alias, str): - info = inspect(alias) - alias = info.selectable - + coerced_alias = coercions.expect(roles.FromClauseRole, alias) else: util.warn_deprecated( "Passing a string name for the 'alias' argument to " @@ -105,21 +144,28 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): "sqlalchemy.orm.aliased() construct.", version="1.4", ) + coerced_alias = alias elif getattr(attr, "_of_type", None): - ot = inspect(attr._of_type) - alias = ot.selectable + assert isinstance(attr, QueryableAttribute) + ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type) + assert ot is not None + coerced_alias = ot.selectable + else: + coerced_alias = None cloned = self._set_relationship_strategy( attr, {"lazy": "joined"}, propagate_to_loaders=False, - opts={"eager_from_alias": alias}, + opts={"eager_from_alias": coerced_alias}, _reconcile_to_other=True if _is_chain else None, ) return cloned - def load_only(self, *attrs): + def load_only( + self: Self_AbstractLoad, *attrs: _AttrType + ) -> Self_AbstractLoad: """Indicate that for a particular entity, only the given list of column-based attribute names should be loaded; all others will be deferred. @@ -159,11 +205,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): {"deferred": False, "instrument": True}, ) cloned = cloned._set_column_strategy( - "*", {"deferred": True, "instrument": True}, {"undefer_pks": True} + ("*",), + {"deferred": True, "instrument": True}, + {"undefer_pks": True}, ) return cloned - def joinedload(self, attr, innerjoin=None): + def joinedload( + self: Self_AbstractLoad, + attr: _AttrType, + innerjoin: Optional[bool] = None, + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using joined eager loading. @@ -258,7 +310,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): ) return loader - def subqueryload(self, attr): + def subqueryload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using subquery eager loading. @@ -289,7 +343,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "subquery"}) - def selectinload(self, attr): + def selectinload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using SELECT IN eager loading. @@ -321,7 +377,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "selectin"}) - def lazyload(self, attr): + def lazyload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using "lazy" loading. @@ -337,7 +395,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, {"lazy": "select"}) - def immediateload(self, attr): + def immediateload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate that the given attribute should be loaded using an immediate load with a per-attribute SELECT statement. @@ -361,7 +421,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): loader = self._set_relationship_strategy(attr, {"lazy": "immediate"}) return loader - def noload(self, attr): + def noload(self: Self_AbstractLoad, attr: _AttrType) -> Self_AbstractLoad: """Indicate that the given relationship attribute should remain unloaded. @@ -387,7 +447,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return self._set_relationship_strategy(attr, {"lazy": "noload"}) - def raiseload(self, attr, sql_only=False): + def raiseload( + self: Self_AbstractLoad, attr: _AttrType, sql_only: bool = False + ) -> Self_AbstractLoad: """Indicate that the given attribute should raise an error if accessed. A relationship attribute configured with :func:`_orm.raiseload` will @@ -428,7 +490,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): attr, {"lazy": "raise_on_sql" if sql_only else "raise"} ) - def defaultload(self, attr): + def defaultload( + self: Self_AbstractLoad, attr: _AttrType + ) -> Self_AbstractLoad: """Indicate an attribute should load using its default loader style. This method is used to link to other loader options further into @@ -463,7 +527,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_relationship_strategy(attr, None) - def defer(self, key, raiseload=False): + def defer( + self: Self_AbstractLoad, key: _AttrType, raiseload: bool = False + ) -> Self_AbstractLoad: r"""Indicate that the given column-oriented attribute should be deferred, e.g. not loaded until accessed. @@ -524,7 +590,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): strategy["raiseload"] = True return self._set_column_strategy((key,), strategy) - def undefer(self, key): + def undefer(self: Self_AbstractLoad, key: _AttrType) -> Self_AbstractLoad: r"""Indicate that the given column-oriented attribute should be undeferred, e.g. specified within the SELECT statement of the entity as a whole. @@ -538,7 +604,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): Examples:: # undefer two columns - session.query(MyClass).options(undefer("col1"), undefer("col2")) + session.query(MyClass).options( + undefer(MyClass.col1), undefer(MyClass.col2) + ) # undefer all columns specific to a single class using Load + * session.query(MyClass, MyOtherClass).options( @@ -546,7 +614,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): # undefer a column on a related object session.query(MyClass).options( - defaultload(MyClass.items).undefer('text')) + defaultload(MyClass.items).undefer(MyClass.text)) :param key: Attribute to be undeferred. @@ -563,7 +631,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): (key,), {"deferred": False, "instrument": True} ) - def undefer_group(self, name): + def undefer_group(self: Self_AbstractLoad, name: str) -> Self_AbstractLoad: """Indicate that columns within the given deferred group name should be undeferred. @@ -591,10 +659,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ return self._set_column_strategy( - _WILDCARD_TOKEN, None, {f"undefer_group_{name}": True} + (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True} ) - def with_expression(self, key, expression): + def with_expression( + self: Self_AbstractLoad, + key: _AttrType, + expression: _ColumnExpressionArgument[Any], + ) -> Self_AbstractLoad: r"""Apply an ad-hoc SQL expression to a "deferred expression" attribute. @@ -626,15 +698,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ - expression = coercions.expect( - roles.LabeledColumnExprRole, _orm_full_deannotate(expression) + expression = _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, expression) ) return self._set_column_strategy( (key,), {"query_expression": True}, opts={"expression": expression} ) - def selectin_polymorphic(self, classes): + def selectin_polymorphic( + self: Self_AbstractLoad, classes: Iterable[Type[Any]] + ) -> Self_AbstractLoad: """Indicate an eager load should take place for all attributes specific to a subclass. @@ -659,25 +733,37 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): ) return self - def _coerce_strat(self, strategy): + @overload + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: + ... + + @overload + def _coerce_strat(self, strategy: Literal[None]) -> None: + ... + + def _coerce_strat( + self, strategy: Optional[_StrategySpec] + ) -> Optional[_StrategyKey]: if strategy is not None: - strategy = tuple(sorted(strategy.items())) - return strategy + strategy_key = tuple(sorted(strategy.items())) + else: + strategy_key = None + return strategy_key @_generative def _set_relationship_strategy( self: Self_AbstractLoad, - attr, - strategy, - propagate_to_loaders=True, - opts=None, - _reconcile_to_other=None, + attr: _AttrType, + strategy: Optional[_StrategySpec], + propagate_to_loaders: bool = True, + opts: Optional[_OptsType] = None, + _reconcile_to_other: Optional[bool] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( (attr,), - strategy, + strategy_key, _RELATIONSHIP_TOKEN, opts=opts, propagate_to_loaders=propagate_to_loaders, @@ -687,13 +773,16 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_column_strategy( - self: Self_AbstractLoad, attrs, strategy, opts=None + self: Self_AbstractLoad, + attrs: Tuple[_AttrType, ...], + strategy: Optional[_StrategySpec], + opts: Optional[_OptsType] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( attrs, - strategy, + strategy_key, _COLUMN_TOKEN, opts=opts, attr_group=attrs, @@ -702,12 +791,15 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_generic_strategy( - self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None + self: Self_AbstractLoad, + attrs: Tuple[_AttrType, ...], + strategy: _StrategySpec, + _reconcile_to_other: Optional[bool] = None, ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) self._clone_for_bind_strategy( attrs, - strategy, + strategy_key, None, propagate_to_loaders=True, reconcile_to_other=_reconcile_to_other, @@ -716,14 +808,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): @_generative def _set_class_strategy( - self: Self_AbstractLoad, strategy, opts + self: Self_AbstractLoad, strategy: _StrategySpec, opts: _OptsType ) -> Self_AbstractLoad: - strategy = self._coerce_strat(strategy) + strategy_key = self._coerce_strat(strategy) - self._clone_for_bind_strategy(None, strategy, None, opts=opts) + self._clone_for_bind_strategy(None, strategy_key, None, opts=opts) return self - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm._AbstractLoad` object as a sub-option o a :class:`_orm.Load` object. @@ -732,7 +824,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ raise NotImplementedError() - def options(self: Self_AbstractLoad, *opts) -> NoReturn: + def options( + self: Self_AbstractLoad, *opts: _AbstractLoad + ) -> Self_AbstractLoad: r"""Apply a series of options as sub-options to this :class:`_orm._AbstractLoad` object. @@ -742,20 +836,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): raise NotImplementedError() def _clone_for_bind_strategy( - self, - attrs, - strategy, - wildcard_key, - opts=None, - attr_group=None, - propagate_to_loaders=True, - reconcile_to_other=None, - ): + self: Self_AbstractLoad, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> Self_AbstractLoad: raise NotImplementedError() def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: if not compile_state.compile_options._enable_eagerloads: return @@ -768,7 +864,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): not bool(compile_state.current_path), ) - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState) -> None: if not compile_state.compile_options._enable_eagerloads: return @@ -779,12 +875,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): and not compile_state.compile_options._for_refresh_state, ) - def _process(self, compile_state, mapper_entities, raiseerr): + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: """implemented by subclasses""" raise NotImplementedError() @classmethod - def _chop_path(cls, to_chop, path, debug=False): + def _chop_path( + cls, + to_chop: _PathRepresentation, + path: PathRegistry, + debug: bool = False, + ) -> Optional[_PathRepresentation]: i = -1 for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)): @@ -793,7 +899,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return to_chop elif ( c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}" - and c_token != p_token.key + and c_token != p_token.key # type: ignore ): return None @@ -801,9 +907,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): continue elif ( isinstance(c_token, InspectionAttr) - and c_token.is_mapper + and insp_is_mapper(c_token) and ( - (p_token.is_mapper and c_token.isa(p_token)) + (insp_is_mapper(p_token) and c_token.isa(p_token)) or ( # a too-liberal check here to allow a path like # A->A.bs->B->B.cs->C->C.ds, natural path, to chop @@ -827,10 +933,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): # test_of_type.py->test_all_subq_query # i >= 2 - and p_token.is_aliased_class + and insp_is_aliased_class(p_token) and p_token._is_with_polymorphic and c_token in p_token.with_polymorphic_mappers - # and (breakpoint() or True) ) ) ): @@ -841,7 +946,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return to_chop[i + 1 :] -SelfLoad = typing.TypeVar("SelfLoad", bound="Load") +SelfLoad = TypeVar("SelfLoad", bound="Load") class Load(_AbstractLoad): @@ -903,28 +1008,28 @@ class Load(_AbstractLoad): _cache_key_traversal = None path: PathRegistry - context: Tuple["_LoadElement", ...] + context: Tuple[_LoadElement, ...] - def __init__(self, entity): - insp = cast(Union["Mapper", AliasedInsp], inspect(entity)) + def __init__(self, entity: _EntityType[Any]): + insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity)) insp._post_inspect self.path = insp._path_registry self.context = () self.propagate_to_loaders = False - def __str__(self): + def __str__(self) -> str: return f"Load({self.path[0]})" @classmethod - def _construct_for_existing_path(cls, path): + def _construct_for_existing_path(cls, path: PathRegistry) -> Load: load = cls.__new__(cls) load.path = path load.context = () load.propagate_to_loaders = False return load - def _adjust_for_extra_criteria(self, context): + def _adjust_for_extra_criteria(self, context: QueryContext) -> Load: """Apply the current bound parameters in a QueryContext to all occurrences "extra_criteria" stored within this ``Load`` object, returning a new instance of this ``Load`` object. @@ -932,10 +1037,10 @@ class Load(_AbstractLoad): """ orig_query = context.compile_state.select_statement - orig_cache_key = None - replacement_cache_key = None + orig_cache_key: Optional[CacheKey] = None + replacement_cache_key: Optional[CacheKey] = None - def process(opt): + def process(opt: _LoadElement) -> _LoadElement: if not opt._extra_criteria: return opt @@ -948,6 +1053,9 @@ class Load(_AbstractLoad): orig_cache_key = orig_query._generate_cache_key() replacement_cache_key = context.query._generate_cache_key() + assert orig_cache_key is not None + assert replacement_cache_key is not None + opt._extra_criteria = tuple( replacement_cache_key._apply_params_to_element( orig_cache_key, crit @@ -975,12 +1083,22 @@ class Load(_AbstractLoad): ezero = None for ent in mapper_entities: ezero = ent.entity_zero - if ezero and orm_util._entity_corresponds_to(ezero, path[0]): + if ezero and orm_util._entity_corresponds_to( + # technically this can be a token also, but this is + # safe to pass to _entity_corresponds_to() + ezero, + cast("_InternalEntityType[Any]", path[0]), + ): return ezero return None - def _process(self, compile_state, mapper_entities, raiseerr): + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: reconciled_lead_entity = self._reconcile_query_entities_with_us( mapper_entities, raiseerr @@ -995,7 +1113,7 @@ class Load(_AbstractLoad): raiseerr, ) - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm.Load` object as a sub-option of another :class:`_orm.Load` object. @@ -1007,7 +1125,8 @@ class Load(_AbstractLoad): assert cloned.propagate_to_loaders == self.propagate_to_loaders if not orm_util._entity_corresponds_to_use_path_impl( - parent.path[-1], cloned.path[0] + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), ): raise sa_exc.ArgumentError( f'Attribute "{cloned.path[1]}" does not link ' @@ -1025,7 +1144,7 @@ class Load(_AbstractLoad): parent.context += cloned.context @_generative - def options(self: SelfLoad, *opts) -> SelfLoad: + def options(self: SelfLoad, *opts: _AbstractLoad) -> SelfLoad: r"""Apply a series of options as sub-options to this :class:`_orm.Load` object. @@ -1062,38 +1181,36 @@ class Load(_AbstractLoad): return self def _clone_for_bind_strategy( - self, - attrs, - strategy, - wildcard_key, - opts=None, - attr_group=None, - propagate_to_loaders=True, - reconcile_to_other=None, - ) -> None: + self: SelfLoad, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + ) -> SelfLoad: # for individual strategy that needs to propagate, set the whole # Load container to also propagate, so that it shows up in # InstanceState.load_options if propagate_to_loaders: self.propagate_to_loaders = True - if not self.path.has_entity: - if self.path.is_token: + if self.path.is_token: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by another entity" + ) + + elif path_is_property(self.path): + # re-use the lookup which will raise a nicely formatted + # LoaderStrategyException + if strategy: + self.path.prop._strategy_lookup(self.path.prop, strategy[0]) + else: raise sa_exc.ArgumentError( - "Wildcard token cannot be followed by another entity" + f"Mapped attribute '{self.path.prop}' does not " + "refer to a mapped entity" ) - else: - # re-use the lookup which will raise a nicely formatted - # LoaderStrategyException - if strategy: - self.path.prop._strategy_lookup( - self.path.prop, strategy[0] - ) - else: - raise sa_exc.ArgumentError( - f"Mapped attribute '{self.path.prop}' does not " - "refer to a mapped entity" - ) if attrs is None: load_element = _ClassStrategyLoad.create( @@ -1140,6 +1257,7 @@ class Load(_AbstractLoad): if wildcard_key is _RELATIONSHIP_TOKEN: self.path = load_element.path self.context += (load_element,) + return self def __getstate__(self): d = self._shallow_to_dict() @@ -1151,7 +1269,7 @@ class Load(_AbstractLoad): self._shallow_from_dict(state) -SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad") +SelfWildcardLoad = TypeVar("SelfWildcardLoad", bound="_WildcardLoad") class _WildcardLoad(_AbstractLoad): @@ -1167,14 +1285,14 @@ class _WildcardLoad(_AbstractLoad): visitors.ExtendedInternalTraversal.dp_string_multi_dict, ), ] - cache_key_traversal = None + cache_key_traversal: _CacheKeyTraversalType = None strategy: Optional[Tuple[Any, ...]] - local_opts: Mapping[str, Any] + local_opts: _OptsType path: Tuple[str, ...] propagate_to_loaders = False - def __init__(self): + def __init__(self) -> None: self.path = () self.strategy = None self.local_opts = util.EMPTY_DICT @@ -1189,6 +1307,7 @@ class _WildcardLoad(_AbstractLoad): propagate_to_loaders=True, reconcile_to_other=None, ): + assert attrs is not None attr = attrs[0] assert ( wildcard_key @@ -1203,10 +1322,12 @@ class _WildcardLoad(_AbstractLoad): if opts: self.local_opts = util.immutabledict(opts) - def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad: + def options( + self: SelfWildcardLoad, *opts: _AbstractLoad + ) -> SelfWildcardLoad: raise NotImplementedError("Star option does not support sub-options") - def _apply_to_parent(self, parent): + def _apply_to_parent(self, parent: Load) -> None: """apply this :class:`_orm._WildcardLoad` object as a sub-option of a :class:`_orm.Load` object. @@ -1215,12 +1336,11 @@ class _WildcardLoad(_AbstractLoad): it may be used as the sub-option of a :class:`_orm.Load` object. """ - attr = self.path[0] if attr.endswith(_DEFAULT_TOKEN): attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}" - effective_path = parent.path.token(attr) + effective_path = cast(AbstractEntityRegistry, parent.path).token(attr) assert effective_path.is_token @@ -1244,20 +1364,21 @@ class _WildcardLoad(_AbstractLoad): entities = [ent.entity_zero for ent in mapper_entities] current_path = compile_state.current_path - start_path = self.path + start_path: _PathRepresentation = self.path # TODO: chop_path already occurs in loader.process_compile_state() # so we will seek to simplify this if current_path: - start_path = self._chop_path(start_path, current_path) - if not start_path: + new_path = self._chop_path(start_path, current_path) + if not new_path: return + start_path = new_path # start_path is a single-token tuple assert start_path and len(start_path) == 1 token = start_path[0] - + assert isinstance(token, str) entity = self._find_entity_basestring(entities, token, raiseerr) if not entity: @@ -1270,6 +1391,7 @@ class _WildcardLoad(_AbstractLoad): # we just located, then go through the rest of our path # tokens and populate into the Load(). + assert isinstance(token, str) loader = _TokenStrategyLoad.create( path_element._path_registry, token, @@ -1291,7 +1413,12 @@ class _WildcardLoad(_AbstractLoad): return loader - def _find_entity_basestring(self, entities, token, raiseerr): + def _find_entity_basestring( + self, + entities: Iterable[_InternalEntityType[Any]], + token: str, + raiseerr: bool, + ) -> Optional[_InternalEntityType[Any]]: if token.endswith(f":{_WILDCARD_TOKEN}"): if len(list(entities)) != 1: if raiseerr: @@ -1324,11 +1451,11 @@ class _WildcardLoad(_AbstractLoad): else: return None - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: d = self._shallow_to_dict() return d - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self._shallow_from_dict(state) @@ -1372,38 +1499,38 @@ class _LoadElement( _extra_criteria: Tuple[Any, ...] _reconcile_to_other: Optional[bool] - strategy: Tuple[Any, ...] + strategy: Optional[_StrategyKey] path: PathRegistry propagate_to_loaders: bool - local_opts: Mapping[str, Any] + local_opts: util.immutabledict[str, Any] is_token_strategy: bool is_class_strategy: bool - def __hash__(self): + def __hash__(self) -> int: return id(self) def __eq__(self, other): return traversals.compare(self, other) @property - def is_opts_only(self): + def is_opts_only(self) -> bool: return bool(self.local_opts and self.strategy is None) - def _clone(self): + def _clone(self, **kw: Any) -> _LoadElement: cls = self.__class__ s = cls.__new__(cls) self._shallow_copy_to(s) return s - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: d = self._shallow_to_dict() d["path"] = self.path.serialize() return d - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: state["path"] = PathRegistry.deserialize(state["path"]) self._shallow_from_dict(state) @@ -1437,8 +1564,8 @@ class _LoadElement( ) def _adjust_effective_path_for_current_path( - self, effective_path, current_path - ): + self, effective_path: PathRegistry, current_path: PathRegistry + ) -> Optional[PathRegistry]: """receives the 'current_path' entry from an :class:`.ORMCompileState` instance, which is set during lazy loads and secondary loader strategy loads, and adjusts the given path to be relative to the @@ -1456,7 +1583,7 @@ class _LoadElement( """ - chopped_start_path = Load._chop_path(effective_path, current_path) + chopped_start_path = Load._chop_path(effective_path.path, current_path) if not chopped_start_path: return None @@ -1523,16 +1650,16 @@ class _LoadElement( @classmethod def create( cls, - path, - attr, - strategy, - wildcard_key, - local_opts, - propagate_to_loaders, - raiseerr=True, - attr_group=None, - reconcile_to_other=None, - ): + path: PathRegistry, + attr: Optional[_AttrType], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + local_opts: Optional[_OptsType], + propagate_to_loaders: bool, + raiseerr: bool = True, + attr_group: Optional[_AttrGroupType] = None, + reconcile_to_other: Optional[bool] = None, + ) -> _LoadElement: """Create a new :class:`._LoadElement` object.""" opt = cls.__new__(cls) @@ -1554,14 +1681,14 @@ class _LoadElement( path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr) if not path: - return None + return None # type: ignore assert opt.is_token_strategy == path.is_token opt.path = path return opt - def __init__(self, path, strategy, local_opts, propagate_to_loaders): + def __init__(self) -> None: raise NotImplementedError() def _prepend_path_from(self, parent): @@ -1580,7 +1707,8 @@ class _LoadElement( assert cloned.is_class_strategy == self.is_class_strategy if not orm_util._entity_corresponds_to_use_path_impl( - parent.path[-1], cloned.path[0] + cast("_InternalEntityType[Any]", parent.path[-1]), + cast("_InternalEntityType[Any]", cloned.path[0]), ): raise sa_exc.ArgumentError( f'Attribute "{cloned.path[1]}" does not link ' @@ -1592,7 +1720,9 @@ class _LoadElement( return cloned @staticmethod - def _reconcile(replacement, existing): + def _reconcile( + replacement: _LoadElement, existing: _LoadElement + ) -> _LoadElement: """define behavior for when two Load objects are to be put into the context.attributes under the same key. @@ -1670,7 +1800,7 @@ class _AttributeStrategyLoad(_LoadElement): ), ] - _of_type: Union["Mapper", AliasedInsp, None] + _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None] _path_with_polymorphic_path: Optional[PathRegistry] is_class_strategy = False @@ -1812,7 +1942,7 @@ class _AttributeStrategyLoad(_LoadElement): pwpi = inspect( orm_util.AliasedInsp._with_polymorphic_factory( pwpi.mapper.base_mapper, - pwpi.mapper, + (pwpi.mapper,), aliased=True, _use_mapper_path=True, ) @@ -1820,11 +1950,12 @@ class _AttributeStrategyLoad(_LoadElement): start_path = self._path_with_polymorphic_path if current_path: - start_path = self._adjust_effective_path_for_current_path( + new_path = self._adjust_effective_path_for_current_path( start_path, current_path ) - if start_path is None: + if new_path is None: return + start_path = new_path key = ("path_with_polymorphic", start_path.natural_path) if key in context: @@ -1872,6 +2003,7 @@ class _AttributeStrategyLoad(_LoadElement): effective_path = self.path if current_path: + assert effective_path is not None effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) @@ -1985,11 +2117,12 @@ class _TokenStrategyLoad(_LoadElement): ) if current_path: - effective_path = self._adjust_effective_path_for_current_path( + new_effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) - if effective_path is None: + if new_effective_path is None: return [] + effective_path = new_effective_path # for a wildcard token, expand out the path we set # to encompass everything from the query entity on @@ -2048,19 +2181,25 @@ class _ClassStrategyLoad(_LoadElement): effective_path = self.path if current_path: - effective_path = self._adjust_effective_path_for_current_path( + new_effective_path = self._adjust_effective_path_for_current_path( effective_path, current_path ) - if effective_path is None: + if new_effective_path is None: return [] + effective_path = new_effective_path - return [("loader", cast(PathRegistry, effective_path).natural_path)] + return [("loader", effective_path.natural_path)] -def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad: - - lead_element = None +def _generate_from_keys( + meth: Callable[..., _AbstractLoad], + keys: Tuple[_AttrType, ...], + chained: bool, + kw: Any, +) -> _AbstractLoad: + lead_element: Optional[_AbstractLoad] = None + attr: Any for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]): for attr in _keys: if isinstance(attr, str): @@ -2116,7 +2255,9 @@ def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad: return lead_element -def _parse_attr_argument(attr): +def _parse_attr_argument( + attr: _AttrType, +) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]: """parse an attribute or wildcard argument to produce an :class:`._AbstractLoad` instance. @@ -2126,16 +2267,21 @@ def _parse_attr_argument(attr): """ try: - insp = inspect(attr) + # TODO: need to figure out this None thing being returned by + # inspect(), it should not have None as an option in most cases + # if at all + insp: InspectionAttr = inspect(attr) # type: ignore except sa_exc.NoInspectionAvailable as err: raise sa_exc.ArgumentError( "expected ORM mapped attribute for loader strategy argument" ) from err - if insp.is_property: + lead_entity: _InternalEntityType[Any] + + if insp_is_mapper_property(insp): lead_entity = insp.parent prop = insp - elif insp.is_attribute: + elif insp_is_attribute(insp): lead_entity = insp.parent prop = insp.prop else: @@ -2146,7 +2292,7 @@ def _parse_attr_argument(attr): return insp, lead_entity, prop -def loader_unbound_fn(fn): +def loader_unbound_fn(fn: _FN) -> _FN: """decorator that applies docstrings between standalone loader functions and the loader methods on :class:`._AbstractLoad`. @@ -2169,12 +2315,12 @@ See :func:`_orm.{fn.__name__}` for usage examples. @loader_unbound_fn -def contains_eager(*keys, **kw) -> _AbstractLoad: +def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.contains_eager, keys, True, kw) @loader_unbound_fn -def load_only(*attrs) -> _AbstractLoad: +def load_only(*attrs: _AttrType) -> _AbstractLoad: # TODO: attrs against different classes. we likely have to # add some extra state to Load of some kind _, lead_element, _ = _parse_attr_argument(attrs[0]) @@ -2182,47 +2328,47 @@ def load_only(*attrs) -> _AbstractLoad: @loader_unbound_fn -def joinedload(*keys, **kw) -> _AbstractLoad: +def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.joinedload, keys, False, kw) @loader_unbound_fn -def subqueryload(*keys) -> _AbstractLoad: +def subqueryload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.subqueryload, keys, False, {}) @loader_unbound_fn -def selectinload(*keys) -> _AbstractLoad: +def selectinload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.selectinload, keys, False, {}) @loader_unbound_fn -def lazyload(*keys) -> _AbstractLoad: +def lazyload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.lazyload, keys, False, {}) @loader_unbound_fn -def immediateload(*keys) -> _AbstractLoad: +def immediateload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.immediateload, keys, False, {}) @loader_unbound_fn -def noload(*keys) -> _AbstractLoad: +def noload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.noload, keys, False, {}) @loader_unbound_fn -def raiseload(*keys, **kw) -> _AbstractLoad: +def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: return _generate_from_keys(Load.raiseload, keys, False, kw) @loader_unbound_fn -def defaultload(*keys) -> _AbstractLoad: +def defaultload(*keys: _AttrType) -> _AbstractLoad: return _generate_from_keys(Load.defaultload, keys, False, {}) @loader_unbound_fn -def defer(key, *addl_attrs, **kw) -> _AbstractLoad: +def defer(key: _AttrType, *addl_attrs: _AttrType, **kw: Any) -> _AbstractLoad: if addl_attrs: util.warn_deprecated( "The *addl_attrs on orm.defer is deprecated. Please use " @@ -2234,7 +2380,7 @@ def defer(key, *addl_attrs, **kw) -> _AbstractLoad: @loader_unbound_fn -def undefer(key, *addl_attrs) -> _AbstractLoad: +def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad: if addl_attrs: util.warn_deprecated( "The *addl_attrs on orm.undefer is deprecated. Please use " @@ -2246,19 +2392,23 @@ def undefer(key, *addl_attrs) -> _AbstractLoad: @loader_unbound_fn -def undefer_group(name) -> _AbstractLoad: +def undefer_group(name: str) -> _AbstractLoad: element = _WildcardLoad() return element.undefer_group(name) @loader_unbound_fn -def with_expression(key, expression) -> _AbstractLoad: +def with_expression( + key: _AttrType, expression: _ColumnExpressionArgument[Any] +) -> _AbstractLoad: return _generate_from_keys( Load.with_expression, (key,), False, {"expression": expression} ) @loader_unbound_fn -def selectin_polymorphic(base_cls, classes) -> _AbstractLoad: +def selectin_polymorphic( + base_cls: _EntityType[Any], classes: Iterable[Type[Any]] +) -> _AbstractLoad: ul = Load(base_cls) return ul.selectin_polymorphic(classes) |