summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/attributes.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-28 16:19:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-03 15:58:45 -0400
commit1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch)
tree9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/attributes.py
parent6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff)
downloadsqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz
pep484: attributes and related
also implements __slots__ for QueryableAttribute, InstrumentedAttribute, Relationship.Comparator. Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/attributes.py')
-rw-r--r--lib/sqlalchemy/orm/attributes.py925
1 files changed, 686 insertions, 239 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9a6e94e22..9aeaeaa27 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Defines instrumentation for class attributes and their interaction
with instances.
@@ -17,16 +17,18 @@ defines a large part of the ORM's interactivity.
from __future__ import annotations
-from collections import namedtuple
+import dataclasses
import operator
from typing import Any
from typing import Callable
-from typing import Collection
+from typing import cast
+from typing import ClassVar
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@ from typing import Union
from . import collections
from . import exc as orm_exc
from . import interfaces
+from ._typing import insp_is_aliased_class
from .base import ATTR_EMPTY
from .base import ATTR_WAS_SET
from .base import CALLABLES_OK
@@ -45,6 +48,7 @@ from .base import instance_dict as instance_dict
from .base import instance_state as instance_state
from .base import instance_str
from .base import LOAD_AGAINST_COMMITTED
+from .base import LoaderCallableStatus
from .base import manager_of_class as manager_of_class
from .base import Mapped as Mapped # noqa
from .base import NEVER_SET # noqa
@@ -70,17 +74,41 @@ from .. import event
from .. import exc
from .. import inspection
from .. import util
+from ..event import dispatcher
+from ..event import EventTarget
from ..sql import base as sql_base
from ..sql import cache_key
+from ..sql import coercions
from ..sql import roles
-from ..sql import traversals
from ..sql import visitors
+from ..util.typing import Literal
+from ..util.typing import TypeGuard
if TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _InstanceDict
+ from ._typing import _InternalEntityType
+ from ._typing import _LoaderCallable
+ from ._typing import _O
+ from .collections import _AdaptedCollectionProtocol
+ from .collections import CollectionAdapter
+ from .dynamic import DynamicAttributeImpl
from .interfaces import MapperProperty
+ from .relationships import Relationship
from .state import InstanceState
- from ..sql.dml import _DMLColumnElement
+ from .util import AliasedInsp
+ from ..event.base import _Dispatch
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _DMLColumnArgument
+ from ..sql._typing import _InfoType
+ from ..sql._typing import _PropagateAttrsType
+ from ..sql.annotation import _AnnotationDict
from ..sql.elements import ColumnElement
+ from ..sql.elements import Label
+ from ..sql.operators import OperatorType
+ from ..sql.selectable import FromClause
+
_T = TypeVar("_T")
@@ -89,19 +117,27 @@ class NoKey(str):
pass
+_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+
NO_KEY = NoKey("no name")
+SelfQueryableAttribute = TypeVar(
+ "SelfQueryableAttribute", bound="QueryableAttribute[Any]"
+)
+
@inspection._self_inspects
class QueryableAttribute(
+ roles.ExpressionElementRole[_T],
interfaces._MappedAttribute[_T],
interfaces.InspectionAttr,
interfaces.PropComparator[_T],
- traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
sql_base.Immutable,
- cache_key.MemoizedHasCacheKey,
+ cache_key.SlotsMemoizedHasCacheKey,
+ util.MemoizedSlots,
+ EventTarget,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
@@ -121,9 +157,33 @@ class QueryableAttribute(
:attr:`_orm.Mapper.attrs`
"""
+ __slots__ = (
+ "class_",
+ "key",
+ "impl",
+ "comparator",
+ "property",
+ "parent",
+ "expression",
+ "_of_type",
+ "_extra_criteria",
+ "_slots_dispatch",
+ "_propagate_attrs",
+ "_doc",
+ )
+
is_attribute = True
+ dispatch: dispatcher[QueryableAttribute[_T]]
+
+ class_: _ExternalEntityType[Any]
+ key: str
+ parententity: _InternalEntityType[Any]
impl: AttributeImpl
+ comparator: interfaces.PropComparator[_T]
+ _of_type: Optional[_InternalEntityType[Any]]
+ _extra_criteria: Tuple[ColumnElement[bool], ...]
+ _doc: Optional[str]
# PropComparator has a __visit_name__ to participate within
# traversals. Disambiguate the attribute vs. a comparator.
@@ -131,21 +191,30 @@ class QueryableAttribute(
def __init__(
self,
- class_,
- key,
- parententity,
- impl=None,
- comparator=None,
- of_type=None,
- extra_criteria=(),
+ class_: _ExternalEntityType[_O],
+ key: str,
+ parententity: _InternalEntityType[_O],
+ comparator: interfaces.PropComparator[_T],
+ impl: Optional[AttributeImpl] = None,
+ of_type: Optional[_InternalEntityType[Any]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
self.class_ = class_
self.key = key
- self._parententity = parententity
- self.impl = impl
+
+ self._parententity = self.parent = parententity
+
+ # this attribute is non-None after mappers are set up, however in the
+ # interim class manager setup, there's a check for None to see if it
+ # needs to be populated, so we assign None here leaving the attribute
+ # in a temporarily not-type-correct state
+ self.impl = impl # type: ignore
+
+ assert comparator is not None
self.comparator = comparator
self._of_type = of_type
self._extra_criteria = extra_criteria
+ self._doc = None
manager = opt_manager_of_class(class_)
# manager is None in the case of AliasedClass
@@ -156,7 +225,7 @@ class QueryableAttribute(
if key in base:
self.dispatch._update(base[key].dispatch)
if base[key].dispatch._active_history:
- self.dispatch._active_history = True
+ self.dispatch._active_history = True # type: ignore
_cache_key_traversal = [
("key", visitors.ExtendedInternalTraversal.dp_string),
@@ -165,7 +234,7 @@ class QueryableAttribute(
("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
]
- def __reduce__(self):
+ def __reduce__(self) -> Any:
# this method is only used in terms of the
# sqlalchemy.ext.serializer extension
return (
@@ -178,21 +247,19 @@ class QueryableAttribute(
),
)
- @util.memoized_property
- def _supports_population(self):
- return self.impl.supports_population
-
@property
- def _impl_uses_objects(self):
+ def _impl_uses_objects(self) -> bool:
return self.impl.uses_objects
- def get_history(self, instance, passive=PASSIVE_OFF):
+ def get_history(
+ self, instance: Any, passive: PassiveFlag = PASSIVE_OFF
+ ) -> History:
return self.impl.get_history(
instance_state(instance), instance_dict(instance), passive
)
- @util.memoized_property
- def info(self):
+ @property
+ def info(self) -> _InfoType:
"""Return the 'info' dictionary for the underlying SQL element.
The behavior here is as follows:
@@ -233,27 +300,28 @@ class QueryableAttribute(
"""
return self.comparator.info
- @util.memoized_property
- def parent(self):
- """Return an inspection instance representing the parent.
+ parent: _InternalEntityType[Any]
+ """Return an inspection instance representing the parent.
- This will be either an instance of :class:`_orm.Mapper`
- or :class:`.AliasedInsp`, depending upon the nature
- of the parent entity which this attribute is associated
- with.
+ This will be either an instance of :class:`_orm.Mapper`
+ or :class:`.AliasedInsp`, depending upon the nature
+ of the parent entity which this attribute is associated
+ with.
- """
- return inspection.inspect(self._parententity)
+ """
- @util.memoized_property
- def expression(self):
- """The SQL expression object represented by this
- :class:`.QueryableAttribute`.
+ expression: ColumnElement[_T]
+ """The SQL expression object represented by this
+ :class:`.QueryableAttribute`.
- This will typically be an instance of a :class:`_sql.ColumnElement`
- subclass representing a column expression.
+ This will typically be an instance of a :class:`_sql.ColumnElement`
+ subclass representing a column expression.
+
+ """
+
+ def _memoized_attr_expression(self) -> ColumnElement[_T]:
+ annotations: _AnnotationDict
- """
if self.key is NO_KEY:
annotations = {"entity_namespace": self._entity_namespace}
else:
@@ -265,6 +333,8 @@ class QueryableAttribute(
ce = self.comparator.__clause_element__()
try:
+ if TYPE_CHECKING:
+ assert isinstance(ce, ColumnElement)
anno = ce._annotate
except AttributeError as ae:
raise exc.InvalidRequestError(
@@ -275,29 +345,42 @@ class QueryableAttribute(
else:
return anno(annotations)
+ def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType:
+ # this suits the case in coercions where we don't actually
+ # call ``__clause_element__()`` but still need to get
+ # resolved._propagate_attrs. See #6558.
+ return util.immutabledict(
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": self._parentmapper,
+ }
+ )
+
@property
- def _entity_namespace(self):
+ def _entity_namespace(self) -> _InternalEntityType[Any]:
return self._parententity
@property
- def _annotations(self):
+ def _annotations(self) -> _AnnotationDict:
return self.__clause_element__()._annotations
def __clause_element__(self) -> ColumnElement[_T]:
return self.expression
@property
- def _from_objects(self):
+ def _from_objects(self) -> List[FromClause]:
return self.expression._from_objects
def _bulk_update_tuples(
self, value: Any
- ) -> List[Tuple[_DMLColumnElement, Any]]:
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
"""Return setter tuples for a bulk UPDATE."""
return self.comparator._bulk_update_tuples(value)
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self: SelfQueryableAttribute, adapt_to_entity: AliasedInsp[Any]
+ ) -> SelfQueryableAttribute:
assert not self._of_type
return self.__class__(
adapt_to_entity.entity,
@@ -307,7 +390,7 @@ class QueryableAttribute(
parententity=adapt_to_entity,
)
- def of_type(self, entity):
+ def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]:
return QueryableAttribute(
self.class_,
self.key,
@@ -318,18 +401,28 @@ class QueryableAttribute(
extra_criteria=self._extra_criteria,
)
- def and_(self, *other):
+ def and_(
+ self, *clauses: _ColumnExpressionArgument[bool]
+ ) -> interfaces.PropComparator[bool]:
+ if TYPE_CHECKING:
+ assert isinstance(self.comparator, Relationship.Comparator)
+
+ exprs = tuple(
+ coercions.expect(roles.WhereHavingRole, clause)
+ for clause in util.coerce_generator_arg(clauses)
+ )
+
return QueryableAttribute(
self.class_,
self.key,
self._parententity,
impl=self.impl,
- comparator=self.comparator.and_(*other),
+ comparator=self.comparator.and_(*exprs),
of_type=self._of_type,
- extra_criteria=self._extra_criteria + other,
+ extra_criteria=self._extra_criteria + exprs,
)
- def _clone(self, **kw):
+ def _clone(self, **kw: Any) -> QueryableAttribute[_T]:
return QueryableAttribute(
self.class_,
self.key,
@@ -340,19 +433,30 @@ class QueryableAttribute(
extra_criteria=self._extra_criteria,
)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[_T]:
return self.__clause_element__().label(name)
- def operate(self, op, *other, **kwargs):
- return op(self.comparator, *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
- return op(other, self.comparator, **kwargs)
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501
- def hasparent(self, state, optimistic=False):
+ def hasparent(
+ self, state: InstanceState[Any], optimistic: bool = False
+ ) -> bool:
return self.impl.hasparent(state, optimistic=optimistic) is not False
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
+ try:
+ return util.MemoizedSlots.__getattr__(self, key)
+ except AttributeError:
+ pass
+
try:
return getattr(self.comparator, key)
except AttributeError as err:
@@ -367,27 +471,22 @@ class QueryableAttribute(
)
) from err
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
- @util.memoized_property
- def property(self) -> MapperProperty[_T]:
- """Return the :class:`.MapperProperty` associated with this
- :class:`.QueryableAttribute`.
-
-
- Return values here will commonly be instances of
- :class:`.ColumnProperty` or :class:`.Relationship`.
-
-
- """
+ def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]:
return self.comparator.property
-def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
+def _queryable_attribute_unreduce(
+ key: str,
+ mapped_class: Type[_O],
+ parententity: _InternalEntityType[_O],
+ entity: _ExternalEntityType[Any],
+) -> Any:
# this method is only used in terms of the
# sqlalchemy.ext.serializer extension
- if parententity.is_aliased_class:
+ if insp_is_aliased_class(parententity):
return entity._get_from_serialized(key, mapped_class, parententity)
else:
return getattr(entity, key)
@@ -402,45 +501,60 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
"""
+ __slots__ = ()
+
inherit_cache = True
- def __set__(self, instance, value):
+ # if not TYPE_CHECKING:
+
+ @property # type: ignore
+ def __doc__(self) -> Optional[str]: # type: ignore
+ return self._doc
+
+ @__doc__.setter
+ def __doc__(self, value: Optional[str]) -> None:
+ self._doc = value
+
+ def __set__(self, instance: object, value: Any) -> None:
self.impl.set(
instance_state(instance), instance_dict(instance), value, None
)
- def __delete__(self, instance):
+ def __delete__(self, instance: object) -> None:
self.impl.delete(instance_state(instance), instance_dict(instance))
@overload
- def __get__(
- self, instance: None, owner: Type[Any]
- ) -> InstrumentedAttribute:
+ def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
...
@overload
- def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]:
+ def __get__(self, instance: object, owner: Any) -> _T:
...
def __get__(
- self, instance: Optional[object], owner: Type[Any]
- ) -> Union[InstrumentedAttribute, Optional[_T]]:
+ self, instance: Optional[object], owner: Any
+ ) -> Union[InstrumentedAttribute[_T], _T]:
if instance is None:
return self
dict_ = instance_dict(instance)
- if self._supports_population and self.key in dict_:
- return dict_[self.key]
+ if self.impl.supports_population and self.key in dict_:
+ return dict_[self.key] # type: ignore[no-any-return]
else:
try:
state = instance_state(instance)
except AttributeError as err:
raise orm_exc.UnmappedInstanceError(instance) from err
- return self.impl.get(state, dict_)
+ return self.impl.get(state, dict_) # type: ignore[no-any-return]
-HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"])
-HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+@dataclasses.dataclass(frozen=True)
+class AdHocHasEntityNamespace:
+ # py37 compat, no slots=True on dataclass
+ __slots__ = ("entity_namespace",)
+ entity_namespace: _ExternalEntityType[Any]
+ is_mapper: ClassVar[bool] = False
+ is_aliased_class: ClassVar[bool] = False
def create_proxied_attribute(
@@ -455,7 +569,7 @@ def create_proxied_attribute(
# TODO: can move this to descriptor_props if the need for this
# function is removed from ext/hybrid.py
- class Proxy(QueryableAttribute):
+ class Proxy(QueryableAttribute[Any]):
"""Presents the :class:`.QueryableAttribute` interface as a
proxy on top of a Python descriptor / :class:`.PropComparator`
combination.
@@ -464,6 +578,10 @@ def create_proxied_attribute(
_extra_criteria = ()
+ # the attribute error catches inside of __getattr__ basically create a
+ # singularity if you try putting slots on this too
+ # __slots__ = ("descriptor", "original_property", "_comparator")
+
def __init__(
self,
class_,
@@ -480,7 +598,15 @@ def create_proxied_attribute(
self.original_property = original_property
self._comparator = comparator
self._adapt_to_entity = adapt_to_entity
- self.__doc__ = doc
+ self._doc = self.__doc__ = doc
+
+ @property
+ def _parententity(self):
+ return inspection.inspect(self.class_)
+
+ @property
+ def parent(self):
+ return inspection.inspect(self.class_)
_is_internal_proxy = True
@@ -497,17 +623,13 @@ def create_proxied_attribute(
)
@property
- def _parententity(self):
- return inspection.inspect(self.class_, raiseerr=False)
-
- @property
def _entity_namespace(self):
if hasattr(self._comparator, "_parententity"):
return self._comparator._parententity
else:
# used by hybrid attributes which try to remain
# agnostic of any ORM concepts like mappers
- return HasEntityNamespace(self.class_)
+ return AdHocHasEntityNamespace(self.class_)
@property
def property(self):
@@ -552,12 +674,22 @@ def create_proxied_attribute(
else:
return retval
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
def __getattr__(self, attribute):
"""Delegate __getattr__ to the original descriptor and/or
comparator."""
+
+ # this is unfortunately very complicated, and is easily prone
+ # to recursion overflows when implementations of related
+ # __getattr__ schemes are changed
+
+ try:
+ return util.MemoizedSlots.__getattr__(self, attribute)
+ except AttributeError:
+ pass
+
try:
return getattr(descriptor, attribute)
except AttributeError as err:
@@ -602,7 +734,7 @@ OP_BULK_REPLACE = util.symbol("BULK_REPLACE")
OP_MODIFIED = util.symbol("MODIFIED")
-class AttributeEvent:
+class AttributeEventToken:
"""A token propagated throughout the course of a chain of attribute
events.
@@ -619,7 +751,8 @@ class AttributeEvent:
event handlers, and is used to control the propagation of operations
across two mutually-dependent attributes.
- .. versionadded:: 0.9.0
+ .. versionchanged:: 2.0 Changed the name from ``AttributeEvent``
+ to ``AttributeEventToken``.
:attribute impl: The :class:`.AttributeImpl` which is the current event
initiator.
@@ -639,7 +772,7 @@ class AttributeEvent:
def __eq__(self, other):
return (
- isinstance(other, AttributeEvent)
+ isinstance(other, AttributeEventToken)
and other.impl is self.impl
and other.op == self.op
)
@@ -652,28 +785,37 @@ class AttributeEvent:
return self.impl.hasparent(state)
-Event = AttributeEvent
+AttributeEvent = AttributeEventToken # legacy
+Event = AttributeEventToken # legacy
class AttributeImpl:
"""internal implementation for instrumented attributes."""
collection: bool
+ default_accepts_scalar_loader: bool
+ uses_objects: bool
+ supports_population: bool
+ dynamic: bool
+
+ _replace_token: AttributeEventToken
+ _remove_token: AttributeEventToken
+ _append_token: AttributeEventToken
def __init__(
self,
- class_,
- key,
- callable_,
- dispatch,
- trackparent=False,
- compare_function=None,
- active_history=False,
- parent_token=None,
- load_on_unexpire=True,
- send_modified_events=True,
- accepts_scalar_loader=None,
- **kwargs,
+ class_: _ExternalEntityType[_O],
+ key: str,
+ callable_: _LoaderCallable,
+ dispatch: _Dispatch[QueryableAttribute[Any]],
+ trackparent: bool = False,
+ compare_function: Optional[Callable[..., bool]] = None,
+ active_history: bool = False,
+ parent_token: Optional[AttributeEventToken] = None,
+ load_on_unexpire: bool = True,
+ send_modified_events: bool = True,
+ accepts_scalar_loader: Optional[bool] = None,
+ **kwargs: Any,
):
r"""Construct an AttributeImpl.
@@ -743,7 +885,7 @@ class AttributeImpl:
self.dispatch._active_history = True
self.load_on_unexpire = load_on_unexpire
- self._modified_token = Event(self, OP_MODIFIED)
+ self._modified_token = AttributeEventToken(self, OP_MODIFIED)
__slots__ = (
"class_",
@@ -760,8 +902,8 @@ class AttributeImpl:
"_deferred_history",
)
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
def _get_active_history(self):
"""Backwards compat for impl.active_history"""
@@ -773,7 +915,9 @@ class AttributeImpl:
active_history = property(_get_active_history, _set_active_history)
- def hasparent(self, state, optimistic=False):
+ def hasparent(
+ self, state: InstanceState[Any], optimistic: bool = False
+ ) -> bool:
"""Return the boolean value of a `hasparent` flag attached to
the given state.
@@ -796,7 +940,12 @@ class AttributeImpl:
state.parents.get(id(self.parent_token), optimistic) is not False
)
- def sethasparent(self, state, parent_state, value):
+ def sethasparent(
+ self,
+ state: InstanceState[Any],
+ parent_state: InstanceState[Any],
+ value: bool,
+ ) -> None:
"""Set a boolean flag on the given item corresponding to
whether or not it is attached to a parent object via the
attribute represented by this ``InstrumentedAttribute``.
@@ -839,11 +988,16 @@ class AttributeImpl:
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- passive=PASSIVE_OFF,
+ passive: PassiveFlag = PASSIVE_OFF,
) -> History:
raise NotImplementedError()
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
"""Return a list of tuples of (state, obj)
for all objects in this attribute's current state
+ history.
@@ -861,7 +1015,9 @@ class AttributeImpl:
"""
raise NotImplementedError()
- def _default_value(self, state, dict_):
+ def _default_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict
+ ) -> Any:
"""Produce an empty value for an uninitialized scalar attribute."""
assert self.key not in dict_, (
@@ -877,7 +1033,12 @@ class AttributeImpl:
return value
- def get(self, state, dict_, passive=PASSIVE_OFF):
+ def get(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Any:
"""Retrieve a value from the given object.
If a callable is assembled on this object's attribute, and
passive is False, the callable will be executed and the
@@ -917,7 +1078,9 @@ class AttributeImpl:
else:
return self._default_value(state, dict_)
- def _fire_loader_callables(self, state, key, passive):
+ def _fire_loader_callables(
+ self, state: InstanceState[Any], key: str, passive: PassiveFlag
+ ) -> Any:
if (
self.accepts_scalar_loader
and self.load_on_unexpire
@@ -932,15 +1095,36 @@ class AttributeImpl:
else:
return ATTR_EMPTY
- def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def append(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(state, dict_, value, initiator, passive=passive)
- def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def remove(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(
state, dict_, None, initiator, passive=passive, check_old=value
)
- def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def pop(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(
state,
dict_,
@@ -953,17 +1137,25 @@ class AttributeImpl:
def set(
self,
- state,
- dict_,
- value,
- initiator,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ ) -> None:
+ raise NotImplementedError()
+
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
raise NotImplementedError()
- def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
+ def get_committed_value(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Any:
"""return the unchanged value of this attribute"""
if self.key in state.committed_state:
@@ -996,10 +1188,12 @@ class ScalarAttributeImpl(AttributeImpl):
def __init__(self, *arg, **kw):
super(ScalarAttributeImpl, self).__init__(*arg, **kw)
- self._replace_token = self._append_token = Event(self, OP_REPLACE)
- self._remove_token = Event(self, OP_REMOVE)
+ self._replace_token = self._append_token = AttributeEventToken(
+ self, OP_REPLACE
+ )
+ self._remove_token = AttributeEventToken(self, OP_REMOVE)
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
else:
@@ -1042,11 +1236,11 @@ class ScalarAttributeImpl(AttributeImpl):
state: InstanceState[Any],
dict_: Dict[str, Any],
value: Any,
- initiator: Optional[Event],
+ initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
check_old: Optional[object] = None,
pop: bool = False,
- ):
+ ) -> None:
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
else:
@@ -1059,21 +1253,30 @@ class ScalarAttributeImpl(AttributeImpl):
state._modified_event(dict_, self, old)
dict_[self.key] = value
- def fire_replace_event(self, state, dict_, value, previous, initiator):
+ def fire_replace_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ previous: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.set:
value = fn(
state, value, previous, initiator or self._replace_token
)
return value
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
for fn in self.dispatch.remove:
fn(state, value, initiator or self._remove_token)
- @property
- def type(self):
- self.property.columns[0].type
-
class ScalarObjectAttributeImpl(ScalarAttributeImpl):
"""represents a scalar-holding InstrumentedAttribute,
@@ -1090,7 +1293,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
__slots__ = ()
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.dispatch._active_history:
old = self.get(
state,
@@ -1122,7 +1325,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
):
raise AttributeError("%s object does not have a value" % self)
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
if self.key in dict_:
current = dict_[self.key]
else:
@@ -1152,7 +1360,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
self, state, current, original=original
)
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
if self.key in dict_:
current = dict_[self.key]
elif passive & CALLABLES_OK:
@@ -1160,6 +1373,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
else:
return []
+ ret: _AllPendingType
+
# can't use __hash__(), can't use __eq__() here
if (
current is not None
@@ -1184,14 +1399,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
def set(
self,
- state,
- dict_,
- value,
- initiator,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ ) -> None:
"""Set a value on the given InstanceState."""
if self.dispatch._active_history:
@@ -1227,7 +1442,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
value = self.fire_replace_event(state, dict_, value, old, initiator)
dict_[self.key] = value
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
if self.trackparent and value not in (
None,
PASSIVE_NO_RESULT,
@@ -1240,7 +1461,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
state._modified_event(dict_, self, value)
- def fire_replace_event(self, state, dict_, value, previous, initiator):
+ def fire_replace_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ previous: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
if self.trackparent:
if previous is not value and previous not in (
None,
@@ -1263,7 +1491,64 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
return value
-class CollectionAttributeImpl(AttributeImpl):
+class HasCollectionAdapter:
+ __slots__ = ()
+
+ def _dispose_previous_collection(
+ self,
+ state: InstanceState[Any],
+ collection: _AdaptedCollectionProtocol,
+ adapter: CollectionAdapter,
+ fire_event: bool,
+ ) -> None:
+ raise NotImplementedError()
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ raise NotImplementedError()
+
+
+if TYPE_CHECKING:
+
+ def _is_collection_attribute_impl(
+ impl: AttributeImpl,
+ ) -> TypeGuard[CollectionAttributeImpl]:
+ ...
+
+else:
+ _is_collection_attribute_impl = operator.attrgetter("collection")
+
+
+class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
"""A collection-holding attribute that instruments changes in membership.
Only handles collections of instrumented objects.
@@ -1275,12 +1560,14 @@ class CollectionAttributeImpl(AttributeImpl):
"""
- default_accepts_scalar_loader = False
uses_objects = True
- supports_population = True
collection = True
+ default_accepts_scalar_loader = False
+ supports_population = True
dynamic = False
+ _bulk_replace_token: AttributeEventToken
+
__slots__ = (
"copy",
"collection_factory",
@@ -1316,9 +1603,9 @@ class CollectionAttributeImpl(AttributeImpl):
copy_function = self.__copy
self.copy = copy_function
self.collection_factory = typecallable
- self._append_token = Event(self, OP_APPEND)
- self._remove_token = Event(self, OP_REMOVE)
- self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
+ self._append_token = AttributeEventToken(self, OP_APPEND)
+ self._remove_token = AttributeEventToken(self, OP_REMOVE)
+ self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE)
self._duck_typed_as = util.duck_type_collection(
self.collection_factory()
)
@@ -1336,14 +1623,24 @@ class CollectionAttributeImpl(AttributeImpl):
def __copy(self, item):
return [y for y in collections.collection_adapter(item)]
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
current = self.get(state, dict_, passive=passive)
if current is PASSIVE_NO_RESULT:
return HISTORY_BLANK
else:
return History.from_collection(self, state, current)
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
# NOTE: passive is ignored here at the moment
if self.key not in dict_:
@@ -1383,7 +1680,13 @@ class CollectionAttributeImpl(AttributeImpl):
return [(instance_state(o), o) for o in current]
- def fire_append_event(self, state, dict_, value, initiator):
+ def fire_append_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.append:
value = fn(state, value, initiator or self._append_token)
@@ -1394,13 +1697,24 @@ class CollectionAttributeImpl(AttributeImpl):
return value
- def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+ def fire_append_wo_mutation_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.append_wo_mutation:
value = fn(state, value, initiator or self._append_token)
return value
- def fire_pre_remove_event(self, state, dict_, initiator):
+ def fire_pre_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
"""A special event used for pop() operations.
The "remove" event needs to have the item to be removed passed to
@@ -1411,7 +1725,13 @@ class CollectionAttributeImpl(AttributeImpl):
"""
state._modified_event(dict_, self, NO_VALUE, True)
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
if self.trackparent and value is not None:
self.sethasparent(instance_state(value), state, False)
@@ -1420,7 +1740,7 @@ class CollectionAttributeImpl(AttributeImpl):
state._modified_event(dict_, self, NO_VALUE, True)
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.key not in dict_:
return
@@ -1433,7 +1753,9 @@ class CollectionAttributeImpl(AttributeImpl):
# del is a no-op if collection not present.
del dict_[self.key]
- def _default_value(self, state, dict_):
+ def _default_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict
+ ) -> _AdaptedCollectionProtocol:
"""Produce an empty collection for an un-initialized attribute"""
assert self.key not in dict_, (
@@ -1448,7 +1770,9 @@ class CollectionAttributeImpl(AttributeImpl):
adapter._set_empty(user_data)
return user_data
- def _initialize_collection(self, state):
+ def _initialize_collection(
+ self, state: InstanceState[Any]
+ ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]:
adapter, collection = state.manager.initialize_collection(
self.key, state, self.collection_factory
@@ -1458,7 +1782,14 @@ class CollectionAttributeImpl(AttributeImpl):
return adapter, collection
- def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def append(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
collection = self.get_collection(state, dict_, passive=passive)
if collection is PASSIVE_NO_RESULT:
value = self.fire_append_event(state, dict_, value, initiator)
@@ -1467,9 +1798,18 @@ class CollectionAttributeImpl(AttributeImpl):
), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).append(value)
else:
+ if TYPE_CHECKING:
+ assert isinstance(collection, CollectionAdapter)
collection.append_with_event(value, initiator)
- def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def remove(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
collection = self.get_collection(state, state.dict, passive=passive)
if collection is PASSIVE_NO_RESULT:
self.fire_remove_event(state, dict_, value, initiator)
@@ -1478,9 +1818,18 @@ class CollectionAttributeImpl(AttributeImpl):
), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).remove(value)
else:
+ if TYPE_CHECKING:
+ assert isinstance(collection, CollectionAdapter)
collection.remove_with_event(value, initiator)
- def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def pop(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
try:
# TODO: better solution here would be to add
# a "popper" role to collections.py to complement
@@ -1491,15 +1840,15 @@ class CollectionAttributeImpl(AttributeImpl):
def set(
self,
- state,
- dict_,
- value,
- initiator=None,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- _adapt=True,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
iterable = orig_iterable = value
# pulling a new collection first so that an adaptation exception does
@@ -1518,7 +1867,7 @@ class CollectionAttributeImpl(AttributeImpl):
and "None"
or iterable.__class__.__name__
)
- wanted = self._duck_typed_as.__name__
+ wanted = self._duck_typed_as.__name__ # type: ignore
raise TypeError(
"Incompatible collection type: %s is not %s-like"
% (given, wanted)
@@ -1560,8 +1909,12 @@ class CollectionAttributeImpl(AttributeImpl):
self._dispose_previous_collection(state, old, old_collection, True)
def _dispose_previous_collection(
- self, state, collection, adapter, fire_event
- ):
+ self,
+ state: InstanceState[Any],
+ collection: _AdaptedCollectionProtocol,
+ adapter: CollectionAdapter,
+ fire_event: bool,
+ ) -> None:
del collection._sa_adapter
# discarding old collection make sure it is not referenced in empty
@@ -1570,11 +1923,15 @@ class CollectionAttributeImpl(AttributeImpl):
if fire_event:
self.dispatch.dispose_collection(state, collection, adapter)
- def _invalidate_collection(self, collection: Collection) -> None:
+ def _invalidate_collection(
+ self, collection: _AdaptedCollectionProtocol
+ ) -> None:
adapter = getattr(collection, "_sa_adapter")
adapter.invalidated = True
- def set_committed_value(self, state, dict_, value):
+ def set_committed_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict, value: Any
+ ) -> _AdaptedCollectionProtocol:
"""Set an attribute value on the given instance and 'commit' it."""
collection, user_data = self._initialize_collection(state)
@@ -1601,9 +1958,37 @@ class CollectionAttributeImpl(AttributeImpl):
return user_data
+ @overload
def get_collection(
- self, state, dict_, user_data=None, passive=PASSIVE_OFF
- ):
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
"""Retrieve the CollectionAdapter associated with the given state.
if user_data is None, retrieves it from the state using normal
@@ -1612,14 +1997,18 @@ class CollectionAttributeImpl(AttributeImpl):
"""
if user_data is None:
- user_data = self.get(state, dict_, passive=passive)
- if user_data is PASSIVE_NO_RESULT:
- return user_data
+ fetch_user_data = self.get(state, dict_, passive=passive)
+ if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT:
+ return fetch_user_data
+ else:
+ user_data = cast("_AdaptedCollectionProtocol", fetch_user_data)
return user_data._sa_adapter
-def backref_listeners(attribute, key, uselist):
+def backref_listeners(
+ attribute: QueryableAttribute[Any], key: str, uselist: bool
+) -> None:
"""Apply listeners to synchronize a two-way relationship."""
# use easily recognizable names for stack traces.
@@ -1695,7 +2084,7 @@ def backref_listeners(attribute, key, uselist):
check_append_token = child_impl._append_token
check_bulk_replace_token = (
child_impl._bulk_replace_token
- if child_impl.collection
+ if _is_collection_attribute_impl(child_impl)
else None
)
@@ -1728,7 +2117,9 @@ def backref_listeners(attribute, key, uselist):
# tokens to test for a recursive loop.
check_append_token = child_impl._append_token
check_bulk_replace_token = (
- child_impl._bulk_replace_token if child_impl.collection else None
+ child_impl._bulk_replace_token
+ if _is_collection_attribute_impl(child_impl)
+ else None
)
if (
@@ -1756,6 +2147,8 @@ def backref_listeners(attribute, key, uselist):
)
child_impl = child_state.manager[key].impl
+ check_replace_token: Optional[AttributeEventToken]
+
# tokens to test for a recursive loop.
if not child_impl.collection and not child_impl.dynamic:
check_remove_token = child_impl._remove_token
@@ -1765,7 +2158,7 @@ def backref_listeners(attribute, key, uselist):
check_remove_token = child_impl._remove_token
check_replace_token = (
child_impl._bulk_replace_token
- if child_impl.collection
+ if _is_collection_attribute_impl(child_impl)
else None
)
check_for_dupes_on_remove = False
@@ -1848,10 +2241,10 @@ class History(NamedTuple):
unchanged: Union[Tuple[()], List[Any]]
deleted: Union[Tuple[()], List[Any]]
- def __bool__(self):
+ def __bool__(self) -> bool:
return self != HISTORY_BLANK
- def empty(self):
+ def empty(self) -> bool:
"""Return True if this :class:`.History` has no changes
and no existing, unchanged state.
@@ -1859,29 +2252,29 @@ class History(NamedTuple):
return not bool((self.added or self.deleted) or self.unchanged)
- def sum(self):
+ def sum(self) -> Sequence[Any]:
"""Return a collection of added + unchanged + deleted."""
return (
(self.added or []) + (self.unchanged or []) + (self.deleted or [])
)
- def non_deleted(self):
+ def non_deleted(self) -> Sequence[Any]:
"""Return a collection of added + unchanged."""
return (self.added or []) + (self.unchanged or [])
- def non_added(self):
+ def non_added(self) -> Sequence[Any]:
"""Return a collection of unchanged + deleted."""
return (self.unchanged or []) + (self.deleted or [])
- def has_changes(self):
+ def has_changes(self) -> bool:
"""Return True if this :class:`.History` has changes."""
return bool(self.added or self.deleted)
- def as_state(self):
+ def as_state(self) -> History:
return History(
[
(c is not None) and instance_state(c) or None
@@ -1898,9 +2291,16 @@ class History(NamedTuple):
)
@classmethod
- def from_scalar_attribute(cls, attribute, state, current):
+ def from_scalar_attribute(
+ cls,
+ attribute: ScalarAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ ) -> History:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
+ deleted: Union[Tuple[()], List[Any]]
+
if original is _NO_HISTORY:
if current is NO_VALUE:
return cls((), (), ())
@@ -1933,8 +2333,14 @@ class History(NamedTuple):
@classmethod
def from_object_attribute(
- cls, attribute, state, current, original=_NO_HISTORY
- ):
+ cls,
+ attribute: ScalarObjectAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ original: Any = _NO_HISTORY,
+ ) -> History:
+ deleted: Union[Tuple[()], List[Any]]
+
if original is _NO_HISTORY:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
@@ -1965,7 +2371,12 @@ class History(NamedTuple):
return cls([current], (), deleted)
@classmethod
- def from_collection(cls, attribute, state, current):
+ def from_collection(
+ cls,
+ attribute: CollectionAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ ) -> History:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
if current is NO_VALUE:
return cls((), (), ())
@@ -1999,7 +2410,9 @@ class History(NamedTuple):
HISTORY_BLANK = History((), (), ())
-def get_history(obj, key, passive=PASSIVE_OFF):
+def get_history(
+ obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
"""Return a :class:`.History` record for the given object
and attribute key.
@@ -2037,36 +2450,47 @@ def get_history(obj, key, passive=PASSIVE_OFF):
return get_state_history(instance_state(obj), key, passive)
-def get_state_history(state, key, passive=PASSIVE_OFF):
+def get_state_history(
+ state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
return state.get_history(key, passive)
-def has_parent(cls, obj, key, optimistic=False):
+def has_parent(
+ cls: Type[_O], obj: _O, key: str, optimistic: bool = False
+) -> bool:
"""TODO"""
manager = manager_of_class(cls)
state = instance_state(obj)
return manager.has_parent(state, key, optimistic)
-def register_attribute(class_, key, **kw):
- comparator = kw.pop("comparator", None)
- parententity = kw.pop("parententity", None)
- doc = kw.pop("doc", None)
- desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
+def register_attribute(
+ class_: Type[_O],
+ key: str,
+ *,
+ comparator: interfaces.PropComparator[_T],
+ parententity: _InternalEntityType[_O],
+ doc: Optional[str] = None,
+ **kw: Any,
+) -> InstrumentedAttribute[_T]:
+ desc = register_descriptor(
+ class_, key, comparator=comparator, parententity=parententity, doc=doc
+ )
register_attribute_impl(class_, key, **kw)
return desc
def register_attribute_impl(
- class_,
- key,
- uselist=False,
- callable_=None,
- useobject=False,
- impl_class=None,
- backref=None,
- **kw,
-):
+ class_: Type[_O],
+ key: str,
+ uselist: bool = False,
+ callable_: Optional[_LoaderCallable] = None,
+ useobject: bool = False,
+ impl_class: Optional[Type[AttributeImpl]] = None,
+ backref: Optional[str] = None,
+ **kw: Any,
+) -> InstrumentedAttribute[Any]:
manager = manager_of_class(class_)
if uselist:
@@ -2077,10 +2501,18 @@ def register_attribute_impl(
else:
typecallable = kw.pop("typecallable", None)
- dispatch = manager[key].dispatch
+ dispatch = cast(
+ "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch
+ ) # noqa: E501
+
+ impl: AttributeImpl
if impl_class:
- impl = impl_class(class_, key, typecallable, dispatch, **kw)
+ # TODO: this appears to be the DynamicAttributeImpl constructor
+ # which is hardcoded
+ impl = cast("Type[DynamicAttributeImpl]", impl_class)(
+ class_, key, typecallable, dispatch, **kw
+ )
elif uselist:
impl = CollectionAttributeImpl(
class_, key, callable_, dispatch, typecallable=typecallable, **kw
@@ -2102,8 +2534,13 @@ def register_attribute_impl(
def register_descriptor(
- class_, key, comparator=None, parententity=None, doc=None
-):
+ class_: Type[Any],
+ key: str,
+ *,
+ comparator: interfaces.PropComparator[_T],
+ parententity: _InternalEntityType[Any],
+ doc: Optional[str] = None,
+) -> InstrumentedAttribute[_T]:
manager = manager_of_class(class_)
descriptor = InstrumentedAttribute(
@@ -2116,11 +2553,11 @@ def register_descriptor(
return descriptor
-def unregister_attribute(class_, key):
+def unregister_attribute(class_: Type[Any], key: str) -> None:
manager_of_class(class_).uninstrument_attribute(key)
-def init_collection(obj, key):
+def init_collection(obj: object, key: str) -> CollectionAdapter:
"""Initialize a collection attribute and return the collection adapter.
This function is used to provide direct access to collection internals
@@ -2143,7 +2580,9 @@ def init_collection(obj, key):
return init_state_collection(state, dict_, key)
-def init_state_collection(state, dict_, key):
+def init_state_collection(
+ state: InstanceState[Any], dict_: _InstanceDict, key: str
+) -> CollectionAdapter:
"""Initialize a collection attribute and return the collection adapter.
Discards any existing collection which may be there.
@@ -2151,6 +2590,9 @@ def init_state_collection(state, dict_, key):
"""
attr = state.manager[key].impl
+ if TYPE_CHECKING:
+ assert isinstance(attr, HasCollectionAdapter)
+
old = dict_.pop(key, None) # discard old collection
if old is not None:
old_collection = old._sa_adapter
@@ -2182,7 +2624,12 @@ def set_committed_value(instance, key, value):
state.manager[key].impl.set_committed_value(state, dict_, value)
-def set_attribute(instance, key, value, initiator=None):
+def set_attribute(
+ instance: object,
+ key: str,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+) -> None:
"""Set the value of an attribute, firing history events.
This function may be used regardless of instrumentation
@@ -2211,7 +2658,7 @@ def set_attribute(instance, key, value, initiator=None):
state.manager[key].impl.set(state, dict_, value, initiator)
-def get_attribute(instance, key):
+def get_attribute(instance: object, key: str) -> Any:
"""Get the value of an attribute, firing any callables required.
This function may be used regardless of instrumentation
@@ -2225,7 +2672,7 @@ def get_attribute(instance, key):
return state.manager[key].impl.get(state, dict_)
-def del_attribute(instance, key):
+def del_attribute(instance: object, key: str) -> None:
"""Delete the value of an attribute, firing history events.
This function may be used regardless of instrumentation
@@ -2239,7 +2686,7 @@ def del_attribute(instance, key):
state.manager[key].impl.delete(state, dict_)
-def flag_modified(instance, key):
+def flag_modified(instance: object, key: str) -> None:
"""Mark an attribute on an instance as 'modified'.
This sets the 'modified' flag on the instance and
@@ -2262,7 +2709,7 @@ def flag_modified(instance, key):
state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
-def flag_dirty(instance):
+def flag_dirty(instance: object) -> None:
"""Mark an instance as 'dirty' without any specific attribute mentioned.
This is a special operation that will allow the object to travel through