diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-28 16:19:43 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-03 15:58:45 -0400 |
commit | 1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch) | |
tree | 9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/attributes.py | |
parent | 6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff) | |
download | sqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz |
pep484: attributes and related
also implements __slots__ for QueryableAttribute,
InstrumentedAttribute, Relationship.Comparator.
Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/attributes.py')
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 925 |
1 files changed, 686 insertions, 239 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9a6e94e22..9aeaeaa27 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Defines instrumentation for class attributes and their interaction with instances. @@ -17,16 +17,18 @@ defines a large part of the ORM's interactivity. from __future__ import annotations -from collections import namedtuple +import dataclasses import operator from typing import Any from typing import Callable -from typing import Collection +from typing import cast +from typing import ClassVar from typing import Dict from typing import List from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -36,6 +38,7 @@ from typing import Union from . import collections from . import exc as orm_exc from . import interfaces +from ._typing import insp_is_aliased_class from .base import ATTR_EMPTY from .base import ATTR_WAS_SET from .base import CALLABLES_OK @@ -45,6 +48,7 @@ from .base import instance_dict as instance_dict from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED +from .base import LoaderCallableStatus from .base import manager_of_class as manager_of_class from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa @@ -70,17 +74,41 @@ from .. import event from .. import exc from .. import inspection from .. import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import base as sql_base from ..sql import cache_key +from ..sql import coercions from ..sql import roles -from ..sql import traversals from ..sql import visitors +from ..util.typing import Literal +from ..util.typing import TypeGuard if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _LoaderCallable + from ._typing import _O + from .collections import _AdaptedCollectionProtocol + from .collections import CollectionAdapter + from .dynamic import DynamicAttributeImpl from .interfaces import MapperProperty + from .relationships import Relationship from .state import InstanceState - from ..sql.dml import _DMLColumnElement + from .util import AliasedInsp + from ..event.base import _Dispatch + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _AnnotationDict from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.operators import OperatorType + from ..sql.selectable import FromClause + _T = TypeVar("_T") @@ -89,19 +117,27 @@ class NoKey(str): pass +_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]] + NO_KEY = NoKey("no name") +SelfQueryableAttribute = TypeVar( + "SelfQueryableAttribute", bound="QueryableAttribute[Any]" +) + @inspection._self_inspects class QueryableAttribute( + roles.ExpressionElementRole[_T], interfaces._MappedAttribute[_T], interfaces.InspectionAttr, interfaces.PropComparator[_T], - traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, - cache_key.MemoizedHasCacheKey, + cache_key.SlotsMemoizedHasCacheKey, + util.MemoizedSlots, + EventTarget, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -121,9 +157,33 @@ class QueryableAttribute( :attr:`_orm.Mapper.attrs` """ + __slots__ = ( + "class_", + "key", + "impl", + "comparator", + "property", + "parent", + "expression", + "_of_type", + "_extra_criteria", + "_slots_dispatch", + "_propagate_attrs", + "_doc", + ) + is_attribute = True + dispatch: dispatcher[QueryableAttribute[_T]] + + class_: _ExternalEntityType[Any] + key: str + parententity: _InternalEntityType[Any] impl: AttributeImpl + comparator: interfaces.PropComparator[_T] + _of_type: Optional[_InternalEntityType[Any]] + _extra_criteria: Tuple[ColumnElement[bool], ...] + _doc: Optional[str] # PropComparator has a __visit_name__ to participate within # traversals. Disambiguate the attribute vs. a comparator. @@ -131,21 +191,30 @@ class QueryableAttribute( def __init__( self, - class_, - key, - parententity, - impl=None, - comparator=None, - of_type=None, - extra_criteria=(), + class_: _ExternalEntityType[_O], + key: str, + parententity: _InternalEntityType[_O], + comparator: interfaces.PropComparator[_T], + impl: Optional[AttributeImpl] = None, + of_type: Optional[_InternalEntityType[Any]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), ): self.class_ = class_ self.key = key - self._parententity = parententity - self.impl = impl + + self._parententity = self.parent = parententity + + # this attribute is non-None after mappers are set up, however in the + # interim class manager setup, there's a check for None to see if it + # needs to be populated, so we assign None here leaving the attribute + # in a temporarily not-type-correct state + self.impl = impl # type: ignore + + assert comparator is not None self.comparator = comparator self._of_type = of_type self._extra_criteria = extra_criteria + self._doc = None manager = opt_manager_of_class(class_) # manager is None in the case of AliasedClass @@ -156,7 +225,7 @@ class QueryableAttribute( if key in base: self.dispatch._update(base[key].dispatch) if base[key].dispatch._active_history: - self.dispatch._active_history = True + self.dispatch._active_history = True # type: ignore _cache_key_traversal = [ ("key", visitors.ExtendedInternalTraversal.dp_string), @@ -165,7 +234,7 @@ class QueryableAttribute( ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ] - def __reduce__(self): + def __reduce__(self) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension return ( @@ -178,21 +247,19 @@ class QueryableAttribute( ), ) - @util.memoized_property - def _supports_population(self): - return self.impl.supports_population - @property - def _impl_uses_objects(self): + def _impl_uses_objects(self) -> bool: return self.impl.uses_objects - def get_history(self, instance, passive=PASSIVE_OFF): + def get_history( + self, instance: Any, passive: PassiveFlag = PASSIVE_OFF + ) -> History: return self.impl.get_history( instance_state(instance), instance_dict(instance), passive ) - @util.memoized_property - def info(self): + @property + def info(self) -> _InfoType: """Return the 'info' dictionary for the underlying SQL element. The behavior here is as follows: @@ -233,27 +300,28 @@ class QueryableAttribute( """ return self.comparator.info - @util.memoized_property - def parent(self): - """Return an inspection instance representing the parent. + parent: _InternalEntityType[Any] + """Return an inspection instance representing the parent. - This will be either an instance of :class:`_orm.Mapper` - or :class:`.AliasedInsp`, depending upon the nature - of the parent entity which this attribute is associated - with. + This will be either an instance of :class:`_orm.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. - """ - return inspection.inspect(self._parententity) + """ - @util.memoized_property - def expression(self): - """The SQL expression object represented by this - :class:`.QueryableAttribute`. + expression: ColumnElement[_T] + """The SQL expression object represented by this + :class:`.QueryableAttribute`. - This will typically be an instance of a :class:`_sql.ColumnElement` - subclass representing a column expression. + This will typically be an instance of a :class:`_sql.ColumnElement` + subclass representing a column expression. + + """ + + def _memoized_attr_expression(self) -> ColumnElement[_T]: + annotations: _AnnotationDict - """ if self.key is NO_KEY: annotations = {"entity_namespace": self._entity_namespace} else: @@ -265,6 +333,8 @@ class QueryableAttribute( ce = self.comparator.__clause_element__() try: + if TYPE_CHECKING: + assert isinstance(ce, ColumnElement) anno = ce._annotate except AttributeError as ae: raise exc.InvalidRequestError( @@ -275,29 +345,42 @@ class QueryableAttribute( else: return anno(annotations) + def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType: + # this suits the case in coercions where we don't actually + # call ``__clause_element__()`` but still need to get + # resolved._propagate_attrs. See #6558. + return util.immutabledict( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parentmapper, + } + ) + @property - def _entity_namespace(self): + def _entity_namespace(self) -> _InternalEntityType[Any]: return self._parententity @property - def _annotations(self): + def _annotations(self) -> _AnnotationDict: return self.__clause_element__()._annotations def __clause_element__(self) -> ColumnElement[_T]: return self.expression @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return self.expression._from_objects def _bulk_update_tuples( self, value: Any - ) -> List[Tuple[_DMLColumnElement, Any]]: + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: """Return setter tuples for a bulk UPDATE.""" return self.comparator._bulk_update_tuples(value) - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self: SelfQueryableAttribute, adapt_to_entity: AliasedInsp[Any] + ) -> SelfQueryableAttribute: assert not self._of_type return self.__class__( adapt_to_entity.entity, @@ -307,7 +390,7 @@ class QueryableAttribute( parententity=adapt_to_entity, ) - def of_type(self, entity): + def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -318,18 +401,28 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def and_(self, *other): + def and_( + self, *clauses: _ColumnExpressionArgument[bool] + ) -> interfaces.PropComparator[bool]: + if TYPE_CHECKING: + assert isinstance(self.comparator, Relationship.Comparator) + + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(clauses) + ) + return QueryableAttribute( self.class_, self.key, self._parententity, impl=self.impl, - comparator=self.comparator.and_(*other), + comparator=self.comparator.and_(*exprs), of_type=self._of_type, - extra_criteria=self._extra_criteria + other, + extra_criteria=self._extra_criteria + exprs, ) - def _clone(self, **kw): + def _clone(self, **kw: Any) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -340,19 +433,30 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def label(self, name): + def label(self, name: Optional[str]) -> Label[_T]: return self.__clause_element__().label(name) - def operate(self, op, *other, **kwargs): - return op(self.comparator, *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): - return op(other, self.comparator, **kwargs) + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 - def hasparent(self, state, optimistic=False): + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: return self.impl.hasparent(state, optimistic=optimistic) is not False - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: + try: + return util.MemoizedSlots.__getattr__(self, key) + except AttributeError: + pass + try: return getattr(self.comparator, key) except AttributeError as err: @@ -367,27 +471,22 @@ class QueryableAttribute( ) ) from err - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" - @util.memoized_property - def property(self) -> MapperProperty[_T]: - """Return the :class:`.MapperProperty` associated with this - :class:`.QueryableAttribute`. - - - Return values here will commonly be instances of - :class:`.ColumnProperty` or :class:`.Relationship`. - - - """ + def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]: return self.comparator.property -def _queryable_attribute_unreduce(key, mapped_class, parententity, entity): +def _queryable_attribute_unreduce( + key: str, + mapped_class: Type[_O], + parententity: _InternalEntityType[_O], + entity: _ExternalEntityType[Any], +) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension - if parententity.is_aliased_class: + if insp_is_aliased_class(parententity): return entity._get_from_serialized(key, mapped_class, parententity) else: return getattr(entity, key) @@ -402,45 +501,60 @@ class InstrumentedAttribute(QueryableAttribute[_T]): """ + __slots__ = () + inherit_cache = True - def __set__(self, instance, value): + # if not TYPE_CHECKING: + + @property # type: ignore + def __doc__(self) -> Optional[str]: # type: ignore + return self._doc + + @__doc__.setter + def __doc__(self, value: Optional[str]) -> None: + self._doc = value + + def __set__(self, instance: object, value: Any) -> None: self.impl.set( instance_state(instance), instance_dict(instance), value, None ) - def __delete__(self, instance): + def __delete__(self, instance: object) -> None: self.impl.delete(instance_state(instance), instance_dict(instance)) @overload - def __get__( - self, instance: None, owner: Type[Any] - ) -> InstrumentedAttribute: + def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: ... @overload - def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]: + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( - self, instance: Optional[object], owner: Type[Any] - ) -> Union[InstrumentedAttribute, Optional[_T]]: + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: if instance is None: return self dict_ = instance_dict(instance) - if self._supports_population and self.key in dict_: - return dict_[self.key] + if self.impl.supports_population and self.key in dict_: + return dict_[self.key] # type: ignore[no-any-return] else: try: state = instance_state(instance) except AttributeError as err: raise orm_exc.UnmappedInstanceError(instance) from err - return self.impl.get(state, dict_) + return self.impl.get(state, dict_) # type: ignore[no-any-return] -HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"]) -HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False +@dataclasses.dataclass(frozen=True) +class AdHocHasEntityNamespace: + # py37 compat, no slots=True on dataclass + __slots__ = ("entity_namespace",) + entity_namespace: _ExternalEntityType[Any] + is_mapper: ClassVar[bool] = False + is_aliased_class: ClassVar[bool] = False def create_proxied_attribute( @@ -455,7 +569,7 @@ def create_proxied_attribute( # TODO: can move this to descriptor_props if the need for this # function is removed from ext/hybrid.py - class Proxy(QueryableAttribute): + class Proxy(QueryableAttribute[Any]): """Presents the :class:`.QueryableAttribute` interface as a proxy on top of a Python descriptor / :class:`.PropComparator` combination. @@ -464,6 +578,10 @@ def create_proxied_attribute( _extra_criteria = () + # the attribute error catches inside of __getattr__ basically create a + # singularity if you try putting slots on this too + # __slots__ = ("descriptor", "original_property", "_comparator") + def __init__( self, class_, @@ -480,7 +598,15 @@ def create_proxied_attribute( self.original_property = original_property self._comparator = comparator self._adapt_to_entity = adapt_to_entity - self.__doc__ = doc + self._doc = self.__doc__ = doc + + @property + def _parententity(self): + return inspection.inspect(self.class_) + + @property + def parent(self): + return inspection.inspect(self.class_) _is_internal_proxy = True @@ -497,17 +623,13 @@ def create_proxied_attribute( ) @property - def _parententity(self): - return inspection.inspect(self.class_, raiseerr=False) - - @property def _entity_namespace(self): if hasattr(self._comparator, "_parententity"): return self._comparator._parententity else: # used by hybrid attributes which try to remain # agnostic of any ORM concepts like mappers - return HasEntityNamespace(self.class_) + return AdHocHasEntityNamespace(self.class_) @property def property(self): @@ -552,12 +674,22 @@ def create_proxied_attribute( else: return retval - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" def __getattr__(self, attribute): """Delegate __getattr__ to the original descriptor and/or comparator.""" + + # this is unfortunately very complicated, and is easily prone + # to recursion overflows when implementations of related + # __getattr__ schemes are changed + + try: + return util.MemoizedSlots.__getattr__(self, attribute) + except AttributeError: + pass + try: return getattr(descriptor, attribute) except AttributeError as err: @@ -602,7 +734,7 @@ OP_BULK_REPLACE = util.symbol("BULK_REPLACE") OP_MODIFIED = util.symbol("MODIFIED") -class AttributeEvent: +class AttributeEventToken: """A token propagated throughout the course of a chain of attribute events. @@ -619,7 +751,8 @@ class AttributeEvent: event handlers, and is used to control the propagation of operations across two mutually-dependent attributes. - .. versionadded:: 0.9.0 + .. versionchanged:: 2.0 Changed the name from ``AttributeEvent`` + to ``AttributeEventToken``. :attribute impl: The :class:`.AttributeImpl` which is the current event initiator. @@ -639,7 +772,7 @@ class AttributeEvent: def __eq__(self, other): return ( - isinstance(other, AttributeEvent) + isinstance(other, AttributeEventToken) and other.impl is self.impl and other.op == self.op ) @@ -652,28 +785,37 @@ class AttributeEvent: return self.impl.hasparent(state) -Event = AttributeEvent +AttributeEvent = AttributeEventToken # legacy +Event = AttributeEventToken # legacy class AttributeImpl: """internal implementation for instrumented attributes.""" collection: bool + default_accepts_scalar_loader: bool + uses_objects: bool + supports_population: bool + dynamic: bool + + _replace_token: AttributeEventToken + _remove_token: AttributeEventToken + _append_token: AttributeEventToken def __init__( self, - class_, - key, - callable_, - dispatch, - trackparent=False, - compare_function=None, - active_history=False, - parent_token=None, - load_on_unexpire=True, - send_modified_events=True, - accepts_scalar_loader=None, - **kwargs, + class_: _ExternalEntityType[_O], + key: str, + callable_: _LoaderCallable, + dispatch: _Dispatch[QueryableAttribute[Any]], + trackparent: bool = False, + compare_function: Optional[Callable[..., bool]] = None, + active_history: bool = False, + parent_token: Optional[AttributeEventToken] = None, + load_on_unexpire: bool = True, + send_modified_events: bool = True, + accepts_scalar_loader: Optional[bool] = None, + **kwargs: Any, ): r"""Construct an AttributeImpl. @@ -743,7 +885,7 @@ class AttributeImpl: self.dispatch._active_history = True self.load_on_unexpire = load_on_unexpire - self._modified_token = Event(self, OP_MODIFIED) + self._modified_token = AttributeEventToken(self, OP_MODIFIED) __slots__ = ( "class_", @@ -760,8 +902,8 @@ class AttributeImpl: "_deferred_history", ) - def __str__(self): - return "%s.%s" % (self.class_.__name__, self.key) + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" def _get_active_history(self): """Backwards compat for impl.active_history""" @@ -773,7 +915,9 @@ class AttributeImpl: active_history = property(_get_active_history, _set_active_history) - def hasparent(self, state, optimistic=False): + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: """Return the boolean value of a `hasparent` flag attached to the given state. @@ -796,7 +940,12 @@ class AttributeImpl: state.parents.get(id(self.parent_token), optimistic) is not False ) - def sethasparent(self, state, parent_state, value): + def sethasparent( + self, + state: InstanceState[Any], + parent_state: InstanceState[Any], + value: bool, + ) -> None: """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. @@ -839,11 +988,16 @@ class AttributeImpl: self, state: InstanceState[Any], dict_: _InstanceDict, - passive=PASSIVE_OFF, + passive: PassiveFlag = PASSIVE_OFF, ) -> History: raise NotImplementedError() - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: """Return a list of tuples of (state, obj) for all objects in this attribute's current state + history. @@ -861,7 +1015,9 @@ class AttributeImpl: """ raise NotImplementedError() - def _default_value(self, state, dict_): + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> Any: """Produce an empty value for an uninitialized scalar attribute.""" assert self.key not in dict_, ( @@ -877,7 +1033,12 @@ class AttributeImpl: return value - def get(self, state, dict_, passive=PASSIVE_OFF): + def get( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and passive is False, the callable will be executed and the @@ -917,7 +1078,9 @@ class AttributeImpl: else: return self._default_value(state, dict_) - def _fire_loader_callables(self, state, key, passive): + def _fire_loader_callables( + self, state: InstanceState[Any], key: str, passive: PassiveFlag + ) -> Any: if ( self.accepts_scalar_loader and self.load_on_unexpire @@ -932,15 +1095,36 @@ class AttributeImpl: else: return ATTR_EMPTY - def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set(state, dict_, value, initiator, passive=passive) - def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set( state, dict_, None, initiator, passive=passive, check_old=value ) - def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: self.set( state, dict_, @@ -953,17 +1137,25 @@ class AttributeImpl: def set( self, - state, - dict_, - value, - initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + raise NotImplementedError() + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: raise NotImplementedError() - def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): + def get_committed_value( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: """return the unchanged value of this attribute""" if self.key in state.committed_state: @@ -996,10 +1188,12 @@ class ScalarAttributeImpl(AttributeImpl): def __init__(self, *arg, **kw): super(ScalarAttributeImpl, self).__init__(*arg, **kw) - self._replace_token = self._append_token = Event(self, OP_REPLACE) - self._remove_token = Event(self, OP_REMOVE) + self._replace_token = self._append_token = AttributeEventToken( + self, OP_REPLACE + ) + self._remove_token = AttributeEventToken(self, OP_REMOVE) - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.dispatch._active_history: old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) else: @@ -1042,11 +1236,11 @@ class ScalarAttributeImpl(AttributeImpl): state: InstanceState[Any], dict_: Dict[str, Any], value: Any, - initiator: Optional[Event], + initiator: Optional[AttributeEventToken], passive: PassiveFlag = PASSIVE_OFF, check_old: Optional[object] = None, pop: bool = False, - ): + ) -> None: if self.dispatch._active_history: old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) else: @@ -1059,21 +1253,30 @@ class ScalarAttributeImpl(AttributeImpl): state._modified_event(dict_, self, old) dict_[self.key] = value - def fire_replace_event(self, state, dict_, value, previous, initiator): + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.set: value = fn( state, value, previous, initiator or self._replace_token ) return value - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: for fn in self.dispatch.remove: fn(state, value, initiator or self._remove_token) - @property - def type(self): - self.property.columns[0].type - class ScalarObjectAttributeImpl(ScalarAttributeImpl): """represents a scalar-holding InstrumentedAttribute, @@ -1090,7 +1293,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): __slots__ = () - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.dispatch._active_history: old = self.get( state, @@ -1122,7 +1325,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): ): raise AttributeError("%s object does not have a value" % self) - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: if self.key in dict_: current = dict_[self.key] else: @@ -1152,7 +1360,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self, state, current, original=original ) - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: if self.key in dict_: current = dict_[self.key] elif passive & CALLABLES_OK: @@ -1160,6 +1373,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): else: return [] + ret: _AllPendingType + # can't use __hash__(), can't use __eq__() here if ( current is not None @@ -1184,14 +1399,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): def set( self, - state, - dict_, - value, - initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: """Set a value on the given InstanceState.""" if self.dispatch._active_history: @@ -1227,7 +1442,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): value = self.fire_replace_event(state, dict_, value, old, initiator) dict_[self.key] = value - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: if self.trackparent and value not in ( None, PASSIVE_NO_RESULT, @@ -1240,7 +1461,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): state._modified_event(dict_, self, value) - def fire_replace_event(self, state, dict_, value, previous, initiator): + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: if self.trackparent: if previous is not value and previous not in ( None, @@ -1263,7 +1491,64 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): return value -class CollectionAttributeImpl(AttributeImpl): +class HasCollectionAdapter: + __slots__ = () + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + raise NotImplementedError() + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + raise NotImplementedError() + + +if TYPE_CHECKING: + + def _is_collection_attribute_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: + ... + +else: + _is_collection_attribute_impl = operator.attrgetter("collection") + + +class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): """A collection-holding attribute that instruments changes in membership. Only handles collections of instrumented objects. @@ -1275,12 +1560,14 @@ class CollectionAttributeImpl(AttributeImpl): """ - default_accepts_scalar_loader = False uses_objects = True - supports_population = True collection = True + default_accepts_scalar_loader = False + supports_population = True dynamic = False + _bulk_replace_token: AttributeEventToken + __slots__ = ( "copy", "collection_factory", @@ -1316,9 +1603,9 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function self.collection_factory = typecallable - self._append_token = Event(self, OP_APPEND) - self._remove_token = Event(self, OP_REMOVE) - self._bulk_replace_token = Event(self, OP_BULK_REPLACE) + self._append_token = AttributeEventToken(self, OP_APPEND) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE) self._duck_typed_as = util.duck_type_collection( self.collection_factory() ) @@ -1336,14 +1623,24 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in collections.collection_adapter(item)] - def get_history(self, state, dict_, passive=PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: current = self.get(state, dict_, passive=passive) if current is PASSIVE_NO_RESULT: return HISTORY_BLANK else: return History.from_collection(self, state, current) - def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: # NOTE: passive is ignored here at the moment if self.key not in dict_: @@ -1383,7 +1680,13 @@ class CollectionAttributeImpl(AttributeImpl): return [(instance_state(o), o) for o in current] - def fire_append_event(self, state, dict_, value, initiator): + def fire_append_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.append: value = fn(state, value, initiator or self._append_token) @@ -1394,13 +1697,24 @@ class CollectionAttributeImpl(AttributeImpl): return value - def fire_append_wo_mutation_event(self, state, dict_, value, initiator): + def fire_append_wo_mutation_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + ) -> _T: for fn in self.dispatch.append_wo_mutation: value = fn(state, value, initiator or self._append_token) return value - def fire_pre_remove_event(self, state, dict_, initiator): + def fire_pre_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + initiator: Optional[AttributeEventToken], + ) -> None: """A special event used for pop() operations. The "remove" event needs to have the item to be removed passed to @@ -1411,7 +1725,13 @@ class CollectionAttributeImpl(AttributeImpl): """ state._modified_event(dict_, self, NO_VALUE, True) - def fire_remove_event(self, state, dict_, value, initiator): + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: if self.trackparent and value is not None: self.sethasparent(instance_state(value), state, False) @@ -1420,7 +1740,7 @@ class CollectionAttributeImpl(AttributeImpl): state._modified_event(dict_, self, NO_VALUE, True) - def delete(self, state, dict_): + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: if self.key not in dict_: return @@ -1433,7 +1753,9 @@ class CollectionAttributeImpl(AttributeImpl): # del is a no-op if collection not present. del dict_[self.key] - def _default_value(self, state, dict_): + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> _AdaptedCollectionProtocol: """Produce an empty collection for an un-initialized attribute""" assert self.key not in dict_, ( @@ -1448,7 +1770,9 @@ class CollectionAttributeImpl(AttributeImpl): adapter._set_empty(user_data) return user_data - def _initialize_collection(self, state): + def _initialize_collection( + self, state: InstanceState[Any] + ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]: adapter, collection = state.manager.initialize_collection( self.key, state, self.collection_factory @@ -1458,7 +1782,14 @@ class CollectionAttributeImpl(AttributeImpl): return adapter, collection - def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) @@ -1467,9 +1798,18 @@ class CollectionAttributeImpl(AttributeImpl): ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).append(value) else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) collection.append_with_event(value, initiator) - def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) @@ -1478,9 +1818,18 @@ class CollectionAttributeImpl(AttributeImpl): ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).remove(value) else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) collection.remove_with_event(value, initiator) - def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: try: # TODO: better solution here would be to add # a "popper" role to collections.py to complement @@ -1491,15 +1840,15 @@ class CollectionAttributeImpl(AttributeImpl): def set( self, - state, - dict_, - value, - initiator=None, - passive=PASSIVE_OFF, - check_old=None, - pop=False, - _adapt=True, - ): + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: iterable = orig_iterable = value # pulling a new collection first so that an adaptation exception does @@ -1518,7 +1867,7 @@ class CollectionAttributeImpl(AttributeImpl): and "None" or iterable.__class__.__name__ ) - wanted = self._duck_typed_as.__name__ + wanted = self._duck_typed_as.__name__ # type: ignore raise TypeError( "Incompatible collection type: %s is not %s-like" % (given, wanted) @@ -1560,8 +1909,12 @@ class CollectionAttributeImpl(AttributeImpl): self._dispose_previous_collection(state, old, old_collection, True) def _dispose_previous_collection( - self, state, collection, adapter, fire_event - ): + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: del collection._sa_adapter # discarding old collection make sure it is not referenced in empty @@ -1570,11 +1923,15 @@ class CollectionAttributeImpl(AttributeImpl): if fire_event: self.dispatch.dispose_collection(state, collection, adapter) - def _invalidate_collection(self, collection: Collection) -> None: + def _invalidate_collection( + self, collection: _AdaptedCollectionProtocol + ) -> None: adapter = getattr(collection, "_sa_adapter") adapter.invalidated = True - def set_committed_value(self, state, dict_, value): + def set_committed_value( + self, state: InstanceState[Any], dict_: _InstanceDict, value: Any + ) -> _AdaptedCollectionProtocol: """Set an attribute value on the given instance and 'commit' it.""" collection, user_data = self._initialize_collection(state) @@ -1601,9 +1958,37 @@ class CollectionAttributeImpl(AttributeImpl): return user_data + @overload def get_collection( - self, state, dict_, user_data=None, passive=PASSIVE_OFF - ): + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: + ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: """Retrieve the CollectionAdapter associated with the given state. if user_data is None, retrieves it from the state using normal @@ -1612,14 +1997,18 @@ class CollectionAttributeImpl(AttributeImpl): """ if user_data is None: - user_data = self.get(state, dict_, passive=passive) - if user_data is PASSIVE_NO_RESULT: - return user_data + fetch_user_data = self.get(state, dict_, passive=passive) + if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT: + return fetch_user_data + else: + user_data = cast("_AdaptedCollectionProtocol", fetch_user_data) return user_data._sa_adapter -def backref_listeners(attribute, key, uselist): +def backref_listeners( + attribute: QueryableAttribute[Any], key: str, uselist: bool +) -> None: """Apply listeners to synchronize a two-way relationship.""" # use easily recognizable names for stack traces. @@ -1695,7 +2084,7 @@ def backref_listeners(attribute, key, uselist): check_append_token = child_impl._append_token check_bulk_replace_token = ( child_impl._bulk_replace_token - if child_impl.collection + if _is_collection_attribute_impl(child_impl) else None ) @@ -1728,7 +2117,9 @@ def backref_listeners(attribute, key, uselist): # tokens to test for a recursive loop. check_append_token = child_impl._append_token check_bulk_replace_token = ( - child_impl._bulk_replace_token if child_impl.collection else None + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None ) if ( @@ -1756,6 +2147,8 @@ def backref_listeners(attribute, key, uselist): ) child_impl = child_state.manager[key].impl + check_replace_token: Optional[AttributeEventToken] + # tokens to test for a recursive loop. if not child_impl.collection and not child_impl.dynamic: check_remove_token = child_impl._remove_token @@ -1765,7 +2158,7 @@ def backref_listeners(attribute, key, uselist): check_remove_token = child_impl._remove_token check_replace_token = ( child_impl._bulk_replace_token - if child_impl.collection + if _is_collection_attribute_impl(child_impl) else None ) check_for_dupes_on_remove = False @@ -1848,10 +2241,10 @@ class History(NamedTuple): unchanged: Union[Tuple[()], List[Any]] deleted: Union[Tuple[()], List[Any]] - def __bool__(self): + def __bool__(self) -> bool: return self != HISTORY_BLANK - def empty(self): + def empty(self) -> bool: """Return True if this :class:`.History` has no changes and no existing, unchanged state. @@ -1859,29 +2252,29 @@ class History(NamedTuple): return not bool((self.added or self.deleted) or self.unchanged) - def sum(self): + def sum(self) -> Sequence[Any]: """Return a collection of added + unchanged + deleted.""" return ( (self.added or []) + (self.unchanged or []) + (self.deleted or []) ) - def non_deleted(self): + def non_deleted(self) -> Sequence[Any]: """Return a collection of added + unchanged.""" return (self.added or []) + (self.unchanged or []) - def non_added(self): + def non_added(self) -> Sequence[Any]: """Return a collection of unchanged + deleted.""" return (self.unchanged or []) + (self.deleted or []) - def has_changes(self): + def has_changes(self) -> bool: """Return True if this :class:`.History` has changes.""" return bool(self.added or self.deleted) - def as_state(self): + def as_state(self) -> History: return History( [ (c is not None) and instance_state(c) or None @@ -1898,9 +2291,16 @@ class History(NamedTuple): ) @classmethod - def from_scalar_attribute(cls, attribute, state, current): + def from_scalar_attribute( + cls, + attribute: ScalarAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: original = state.committed_state.get(attribute.key, _NO_HISTORY) + deleted: Union[Tuple[()], List[Any]] + if original is _NO_HISTORY: if current is NO_VALUE: return cls((), (), ()) @@ -1933,8 +2333,14 @@ class History(NamedTuple): @classmethod def from_object_attribute( - cls, attribute, state, current, original=_NO_HISTORY - ): + cls, + attribute: ScalarObjectAttributeImpl, + state: InstanceState[Any], + current: Any, + original: Any = _NO_HISTORY, + ) -> History: + deleted: Union[Tuple[()], List[Any]] + if original is _NO_HISTORY: original = state.committed_state.get(attribute.key, _NO_HISTORY) @@ -1965,7 +2371,12 @@ class History(NamedTuple): return cls([current], (), deleted) @classmethod - def from_collection(cls, attribute, state, current): + def from_collection( + cls, + attribute: CollectionAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: original = state.committed_state.get(attribute.key, _NO_HISTORY) if current is NO_VALUE: return cls((), (), ()) @@ -1999,7 +2410,9 @@ class History(NamedTuple): HISTORY_BLANK = History((), (), ()) -def get_history(obj, key, passive=PASSIVE_OFF): +def get_history( + obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: """Return a :class:`.History` record for the given object and attribute key. @@ -2037,36 +2450,47 @@ def get_history(obj, key, passive=PASSIVE_OFF): return get_state_history(instance_state(obj), key, passive) -def get_state_history(state, key, passive=PASSIVE_OFF): +def get_state_history( + state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: return state.get_history(key, passive) -def has_parent(cls, obj, key, optimistic=False): +def has_parent( + cls: Type[_O], obj: _O, key: str, optimistic: bool = False +) -> bool: """TODO""" manager = manager_of_class(cls) state = instance_state(obj) return manager.has_parent(state, key, optimistic) -def register_attribute(class_, key, **kw): - comparator = kw.pop("comparator", None) - parententity = kw.pop("parententity", None) - doc = kw.pop("doc", None) - desc = register_descriptor(class_, key, comparator, parententity, doc=doc) +def register_attribute( + class_: Type[_O], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[_O], + doc: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[_T]: + desc = register_descriptor( + class_, key, comparator=comparator, parententity=parententity, doc=doc + ) register_attribute_impl(class_, key, **kw) return desc def register_attribute_impl( - class_, - key, - uselist=False, - callable_=None, - useobject=False, - impl_class=None, - backref=None, - **kw, -): + class_: Type[_O], + key: str, + uselist: bool = False, + callable_: Optional[_LoaderCallable] = None, + useobject: bool = False, + impl_class: Optional[Type[AttributeImpl]] = None, + backref: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[Any]: manager = manager_of_class(class_) if uselist: @@ -2077,10 +2501,18 @@ def register_attribute_impl( else: typecallable = kw.pop("typecallable", None) - dispatch = manager[key].dispatch + dispatch = cast( + "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch + ) # noqa: E501 + + impl: AttributeImpl if impl_class: - impl = impl_class(class_, key, typecallable, dispatch, **kw) + # TODO: this appears to be the DynamicAttributeImpl constructor + # which is hardcoded + impl = cast("Type[DynamicAttributeImpl]", impl_class)( + class_, key, typecallable, dispatch, **kw + ) elif uselist: impl = CollectionAttributeImpl( class_, key, callable_, dispatch, typecallable=typecallable, **kw @@ -2102,8 +2534,13 @@ def register_attribute_impl( def register_descriptor( - class_, key, comparator=None, parententity=None, doc=None -): + class_: Type[Any], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[Any], + doc: Optional[str] = None, +) -> InstrumentedAttribute[_T]: manager = manager_of_class(class_) descriptor = InstrumentedAttribute( @@ -2116,11 +2553,11 @@ def register_descriptor( return descriptor -def unregister_attribute(class_, key): +def unregister_attribute(class_: Type[Any], key: str) -> None: manager_of_class(class_).uninstrument_attribute(key) -def init_collection(obj, key): +def init_collection(obj: object, key: str) -> CollectionAdapter: """Initialize a collection attribute and return the collection adapter. This function is used to provide direct access to collection internals @@ -2143,7 +2580,9 @@ def init_collection(obj, key): return init_state_collection(state, dict_, key) -def init_state_collection(state, dict_, key): +def init_state_collection( + state: InstanceState[Any], dict_: _InstanceDict, key: str +) -> CollectionAdapter: """Initialize a collection attribute and return the collection adapter. Discards any existing collection which may be there. @@ -2151,6 +2590,9 @@ def init_state_collection(state, dict_, key): """ attr = state.manager[key].impl + if TYPE_CHECKING: + assert isinstance(attr, HasCollectionAdapter) + old = dict_.pop(key, None) # discard old collection if old is not None: old_collection = old._sa_adapter @@ -2182,7 +2624,12 @@ def set_committed_value(instance, key, value): state.manager[key].impl.set_committed_value(state, dict_, value) -def set_attribute(instance, key, value, initiator=None): +def set_attribute( + instance: object, + key: str, + value: Any, + initiator: Optional[AttributeEventToken] = None, +) -> None: """Set the value of an attribute, firing history events. This function may be used regardless of instrumentation @@ -2211,7 +2658,7 @@ def set_attribute(instance, key, value, initiator=None): state.manager[key].impl.set(state, dict_, value, initiator) -def get_attribute(instance, key): +def get_attribute(instance: object, key: str) -> Any: """Get the value of an attribute, firing any callables required. This function may be used regardless of instrumentation @@ -2225,7 +2672,7 @@ def get_attribute(instance, key): return state.manager[key].impl.get(state, dict_) -def del_attribute(instance, key): +def del_attribute(instance: object, key: str) -> None: """Delete the value of an attribute, firing history events. This function may be used regardless of instrumentation @@ -2239,7 +2686,7 @@ def del_attribute(instance, key): state.manager[key].impl.delete(state, dict_) -def flag_modified(instance, key): +def flag_modified(instance: object, key: str) -> None: """Mark an attribute on an instance as 'modified'. This sets the 'modified' flag on the instance and @@ -2262,7 +2709,7 @@ def flag_modified(instance, key): state._modified_event(dict_, impl, NO_VALUE, is_userland=True) -def flag_dirty(instance): +def flag_dirty(instance: object) -> None: """Mark an instance as 'dirty' without any specific attribute mentioned. This is a special operation that will allow the object to travel through |