summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/event/base.py6
-rw-r--r--lib/sqlalchemy/ext/hybrid.py13
-rw-r--r--lib/sqlalchemy/orm/__init__.py2
-rw-r--r--lib/sqlalchemy/orm/_typing.py2
-rw-r--r--lib/sqlalchemy/orm/attributes.py925
-rw-r--r--lib/sqlalchemy/orm/base.py25
-rw-r--r--lib/sqlalchemy/orm/collections.py208
-rw-r--r--lib/sqlalchemy/orm/dynamic.py8
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py121
-rw-r--r--lib/sqlalchemy/orm/interfaces.py45
-rw-r--r--lib/sqlalchemy/orm/relationships.py77
-rw-r--r--lib/sqlalchemy/orm/state.py9
-rw-r--r--lib/sqlalchemy/orm/util.py16
-rw-r--r--lib/sqlalchemy/sql/base.py4
-rw-r--r--lib/sqlalchemy/sql/cache_key.py8
-rw-r--r--lib/sqlalchemy/sql/util.py11
-rw-r--r--lib/sqlalchemy/util/_collections.py4
-rw-r--r--lib/sqlalchemy/util/langhelpers.py13
-rw-r--r--lib/sqlalchemy/util/typing.py39
19 files changed, 1101 insertions, 435 deletions
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index c16f6870b..83b34a17f 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -108,10 +108,12 @@ class _Dispatch(_DispatchCommon[_ET]):
"""
- # In one ORM edge case, an attribute is added to _Dispatch,
- # so __dict__ is used in just that case and potentially others.
+ # "active_history" is an ORM case we add here. ideally a better
+ # system would be in place for ad-hoc attributes.
__slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
+ _active_history: bool
+
_empty_listener_reg: MutableMapping[
Type[_ET], Dict[str, _EmptyListener[_ET]]
] = weakref.WeakKeyDictionary()
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 7200414a1..ea558495b 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -824,15 +824,14 @@ from ..orm import attributes
from ..orm import InspectionAttrExtensionType
from ..orm import interfaces
from ..orm import ORMDescriptor
+from ..sql import roles
from ..sql._typing import is_has_clause_element
from ..sql.elements import ColumnElement
from ..sql.elements import SQLCoreOperations
from ..util.typing import Literal
from ..util.typing import Protocol
-
if TYPE_CHECKING:
- from ..orm._typing import _ORMColumnExprArgument
from ..orm.interfaces import MapperProperty
from ..orm.util import AliasedInsp
from ..sql._typing import _ColumnExpressionArgument
@@ -840,7 +839,6 @@ if TYPE_CHECKING:
from ..sql._typing import _HasClauseElement
from ..sql._typing import _InfoType
from ..sql.operators import OperatorType
- from ..sql.roles import ColumnsClauseRole
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
@@ -1290,7 +1288,7 @@ class Comparator(interfaces.PropComparator[_T]):
):
self.expression = expression
- def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
+ def __clause_element__(self) -> roles.ColumnsClauseRole:
expr = self.expression
if is_has_clause_element(expr):
ret_expr = expr.__clause_element__()
@@ -1306,7 +1304,7 @@ class Comparator(interfaces.PropComparator[_T]):
assert isinstance(ret_expr, ColumnElement)
return ret_expr
- @util.ro_non_memoized_property
+ @util.non_memoized_property
def property(self) -> Optional[interfaces.MapperProperty[_T]]:
return None
@@ -1345,8 +1343,11 @@ class ExprComparator(Comparator[_T]):
else:
return [(self.expression, value)]
- @util.ro_non_memoized_property
+ @util.non_memoized_property
def property(self) -> Optional[MapperProperty[_T]]:
+ # this accessor is not normally used, however is accessed by things
+ # like ORM synonyms if the hybrid is used in this context; the
+ # .property attribute is not necessarily accessible
return self.expression.property # type: ignore
def operate(
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 58900ab99..b7d1df532 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -41,7 +41,7 @@ from ._orm_constructors import synonym as synonym
from ._orm_constructors import SynonymProperty as SynonymProperty
from ._orm_constructors import with_loader_criteria as with_loader_criteria
from ._orm_constructors import with_polymorphic as with_polymorphic
-from .attributes import AttributeEvent as AttributeEvent
+from .attributes import AttributeEventToken as AttributeEventToken
from .attributes import InstrumentedAttribute as InstrumentedAttribute
from .attributes import QueryableAttribute as QueryableAttribute
from .base import class_mapper as class_mapper
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py
index 339844f14..29d82340a 100644
--- a/lib/sqlalchemy/orm/_typing.py
+++ b/lib/sqlalchemy/orm/_typing.py
@@ -47,6 +47,8 @@ if TYPE_CHECKING:
_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
+_ExternalEntityType = Union[Type[_T], "AliasedClass[_T]"]
+
_EntityType = Union[
Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"
]
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9a6e94e22..9aeaeaa27 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Defines instrumentation for class attributes and their interaction
with instances.
@@ -17,16 +17,18 @@ defines a large part of the ORM's interactivity.
from __future__ import annotations
-from collections import namedtuple
+import dataclasses
import operator
from typing import Any
from typing import Callable
-from typing import Collection
+from typing import cast
+from typing import ClassVar
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@ from typing import Union
from . import collections
from . import exc as orm_exc
from . import interfaces
+from ._typing import insp_is_aliased_class
from .base import ATTR_EMPTY
from .base import ATTR_WAS_SET
from .base import CALLABLES_OK
@@ -45,6 +48,7 @@ from .base import instance_dict as instance_dict
from .base import instance_state as instance_state
from .base import instance_str
from .base import LOAD_AGAINST_COMMITTED
+from .base import LoaderCallableStatus
from .base import manager_of_class as manager_of_class
from .base import Mapped as Mapped # noqa
from .base import NEVER_SET # noqa
@@ -70,17 +74,41 @@ from .. import event
from .. import exc
from .. import inspection
from .. import util
+from ..event import dispatcher
+from ..event import EventTarget
from ..sql import base as sql_base
from ..sql import cache_key
+from ..sql import coercions
from ..sql import roles
-from ..sql import traversals
from ..sql import visitors
+from ..util.typing import Literal
+from ..util.typing import TypeGuard
if TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _InstanceDict
+ from ._typing import _InternalEntityType
+ from ._typing import _LoaderCallable
+ from ._typing import _O
+ from .collections import _AdaptedCollectionProtocol
+ from .collections import CollectionAdapter
+ from .dynamic import DynamicAttributeImpl
from .interfaces import MapperProperty
+ from .relationships import Relationship
from .state import InstanceState
- from ..sql.dml import _DMLColumnElement
+ from .util import AliasedInsp
+ from ..event.base import _Dispatch
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _DMLColumnArgument
+ from ..sql._typing import _InfoType
+ from ..sql._typing import _PropagateAttrsType
+ from ..sql.annotation import _AnnotationDict
from ..sql.elements import ColumnElement
+ from ..sql.elements import Label
+ from ..sql.operators import OperatorType
+ from ..sql.selectable import FromClause
+
_T = TypeVar("_T")
@@ -89,19 +117,27 @@ class NoKey(str):
pass
+_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+
NO_KEY = NoKey("no name")
+SelfQueryableAttribute = TypeVar(
+ "SelfQueryableAttribute", bound="QueryableAttribute[Any]"
+)
+
@inspection._self_inspects
class QueryableAttribute(
+ roles.ExpressionElementRole[_T],
interfaces._MappedAttribute[_T],
interfaces.InspectionAttr,
interfaces.PropComparator[_T],
- traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
sql_base.Immutable,
- cache_key.MemoizedHasCacheKey,
+ cache_key.SlotsMemoizedHasCacheKey,
+ util.MemoizedSlots,
+ EventTarget,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
@@ -121,9 +157,33 @@ class QueryableAttribute(
:attr:`_orm.Mapper.attrs`
"""
+ __slots__ = (
+ "class_",
+ "key",
+ "impl",
+ "comparator",
+ "property",
+ "parent",
+ "expression",
+ "_of_type",
+ "_extra_criteria",
+ "_slots_dispatch",
+ "_propagate_attrs",
+ "_doc",
+ )
+
is_attribute = True
+ dispatch: dispatcher[QueryableAttribute[_T]]
+
+ class_: _ExternalEntityType[Any]
+ key: str
+ parententity: _InternalEntityType[Any]
impl: AttributeImpl
+ comparator: interfaces.PropComparator[_T]
+ _of_type: Optional[_InternalEntityType[Any]]
+ _extra_criteria: Tuple[ColumnElement[bool], ...]
+ _doc: Optional[str]
# PropComparator has a __visit_name__ to participate within
# traversals. Disambiguate the attribute vs. a comparator.
@@ -131,21 +191,30 @@ class QueryableAttribute(
def __init__(
self,
- class_,
- key,
- parententity,
- impl=None,
- comparator=None,
- of_type=None,
- extra_criteria=(),
+ class_: _ExternalEntityType[_O],
+ key: str,
+ parententity: _InternalEntityType[_O],
+ comparator: interfaces.PropComparator[_T],
+ impl: Optional[AttributeImpl] = None,
+ of_type: Optional[_InternalEntityType[Any]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
self.class_ = class_
self.key = key
- self._parententity = parententity
- self.impl = impl
+
+ self._parententity = self.parent = parententity
+
+ # this attribute is non-None after mappers are set up, however in the
+ # interim class manager setup, there's a check for None to see if it
+ # needs to be populated, so we assign None here leaving the attribute
+ # in a temporarily not-type-correct state
+ self.impl = impl # type: ignore
+
+ assert comparator is not None
self.comparator = comparator
self._of_type = of_type
self._extra_criteria = extra_criteria
+ self._doc = None
manager = opt_manager_of_class(class_)
# manager is None in the case of AliasedClass
@@ -156,7 +225,7 @@ class QueryableAttribute(
if key in base:
self.dispatch._update(base[key].dispatch)
if base[key].dispatch._active_history:
- self.dispatch._active_history = True
+ self.dispatch._active_history = True # type: ignore
_cache_key_traversal = [
("key", visitors.ExtendedInternalTraversal.dp_string),
@@ -165,7 +234,7 @@ class QueryableAttribute(
("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
]
- def __reduce__(self):
+ def __reduce__(self) -> Any:
# this method is only used in terms of the
# sqlalchemy.ext.serializer extension
return (
@@ -178,21 +247,19 @@ class QueryableAttribute(
),
)
- @util.memoized_property
- def _supports_population(self):
- return self.impl.supports_population
-
@property
- def _impl_uses_objects(self):
+ def _impl_uses_objects(self) -> bool:
return self.impl.uses_objects
- def get_history(self, instance, passive=PASSIVE_OFF):
+ def get_history(
+ self, instance: Any, passive: PassiveFlag = PASSIVE_OFF
+ ) -> History:
return self.impl.get_history(
instance_state(instance), instance_dict(instance), passive
)
- @util.memoized_property
- def info(self):
+ @property
+ def info(self) -> _InfoType:
"""Return the 'info' dictionary for the underlying SQL element.
The behavior here is as follows:
@@ -233,27 +300,28 @@ class QueryableAttribute(
"""
return self.comparator.info
- @util.memoized_property
- def parent(self):
- """Return an inspection instance representing the parent.
+ parent: _InternalEntityType[Any]
+ """Return an inspection instance representing the parent.
- This will be either an instance of :class:`_orm.Mapper`
- or :class:`.AliasedInsp`, depending upon the nature
- of the parent entity which this attribute is associated
- with.
+ This will be either an instance of :class:`_orm.Mapper`
+ or :class:`.AliasedInsp`, depending upon the nature
+ of the parent entity which this attribute is associated
+ with.
- """
- return inspection.inspect(self._parententity)
+ """
- @util.memoized_property
- def expression(self):
- """The SQL expression object represented by this
- :class:`.QueryableAttribute`.
+ expression: ColumnElement[_T]
+ """The SQL expression object represented by this
+ :class:`.QueryableAttribute`.
- This will typically be an instance of a :class:`_sql.ColumnElement`
- subclass representing a column expression.
+ This will typically be an instance of a :class:`_sql.ColumnElement`
+ subclass representing a column expression.
+
+ """
+
+ def _memoized_attr_expression(self) -> ColumnElement[_T]:
+ annotations: _AnnotationDict
- """
if self.key is NO_KEY:
annotations = {"entity_namespace": self._entity_namespace}
else:
@@ -265,6 +333,8 @@ class QueryableAttribute(
ce = self.comparator.__clause_element__()
try:
+ if TYPE_CHECKING:
+ assert isinstance(ce, ColumnElement)
anno = ce._annotate
except AttributeError as ae:
raise exc.InvalidRequestError(
@@ -275,29 +345,42 @@ class QueryableAttribute(
else:
return anno(annotations)
+ def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType:
+ # this suits the case in coercions where we don't actually
+ # call ``__clause_element__()`` but still need to get
+ # resolved._propagate_attrs. See #6558.
+ return util.immutabledict(
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": self._parentmapper,
+ }
+ )
+
@property
- def _entity_namespace(self):
+ def _entity_namespace(self) -> _InternalEntityType[Any]:
return self._parententity
@property
- def _annotations(self):
+ def _annotations(self) -> _AnnotationDict:
return self.__clause_element__()._annotations
def __clause_element__(self) -> ColumnElement[_T]:
return self.expression
@property
- def _from_objects(self):
+ def _from_objects(self) -> List[FromClause]:
return self.expression._from_objects
def _bulk_update_tuples(
self, value: Any
- ) -> List[Tuple[_DMLColumnElement, Any]]:
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
"""Return setter tuples for a bulk UPDATE."""
return self.comparator._bulk_update_tuples(value)
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self: SelfQueryableAttribute, adapt_to_entity: AliasedInsp[Any]
+ ) -> SelfQueryableAttribute:
assert not self._of_type
return self.__class__(
adapt_to_entity.entity,
@@ -307,7 +390,7 @@ class QueryableAttribute(
parententity=adapt_to_entity,
)
- def of_type(self, entity):
+ def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]:
return QueryableAttribute(
self.class_,
self.key,
@@ -318,18 +401,28 @@ class QueryableAttribute(
extra_criteria=self._extra_criteria,
)
- def and_(self, *other):
+ def and_(
+ self, *clauses: _ColumnExpressionArgument[bool]
+ ) -> interfaces.PropComparator[bool]:
+ if TYPE_CHECKING:
+ assert isinstance(self.comparator, Relationship.Comparator)
+
+ exprs = tuple(
+ coercions.expect(roles.WhereHavingRole, clause)
+ for clause in util.coerce_generator_arg(clauses)
+ )
+
return QueryableAttribute(
self.class_,
self.key,
self._parententity,
impl=self.impl,
- comparator=self.comparator.and_(*other),
+ comparator=self.comparator.and_(*exprs),
of_type=self._of_type,
- extra_criteria=self._extra_criteria + other,
+ extra_criteria=self._extra_criteria + exprs,
)
- def _clone(self, **kw):
+ def _clone(self, **kw: Any) -> QueryableAttribute[_T]:
return QueryableAttribute(
self.class_,
self.key,
@@ -340,19 +433,30 @@ class QueryableAttribute(
extra_criteria=self._extra_criteria,
)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[_T]:
return self.__clause_element__().label(name)
- def operate(self, op, *other, **kwargs):
- return op(self.comparator, *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
- return op(other, self.comparator, **kwargs)
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501
- def hasparent(self, state, optimistic=False):
+ def hasparent(
+ self, state: InstanceState[Any], optimistic: bool = False
+ ) -> bool:
return self.impl.hasparent(state, optimistic=optimistic) is not False
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
+ try:
+ return util.MemoizedSlots.__getattr__(self, key)
+ except AttributeError:
+ pass
+
try:
return getattr(self.comparator, key)
except AttributeError as err:
@@ -367,27 +471,22 @@ class QueryableAttribute(
)
) from err
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
- @util.memoized_property
- def property(self) -> MapperProperty[_T]:
- """Return the :class:`.MapperProperty` associated with this
- :class:`.QueryableAttribute`.
-
-
- Return values here will commonly be instances of
- :class:`.ColumnProperty` or :class:`.Relationship`.
-
-
- """
+ def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]:
return self.comparator.property
-def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
+def _queryable_attribute_unreduce(
+ key: str,
+ mapped_class: Type[_O],
+ parententity: _InternalEntityType[_O],
+ entity: _ExternalEntityType[Any],
+) -> Any:
# this method is only used in terms of the
# sqlalchemy.ext.serializer extension
- if parententity.is_aliased_class:
+ if insp_is_aliased_class(parententity):
return entity._get_from_serialized(key, mapped_class, parententity)
else:
return getattr(entity, key)
@@ -402,45 +501,60 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
"""
+ __slots__ = ()
+
inherit_cache = True
- def __set__(self, instance, value):
+ # if not TYPE_CHECKING:
+
+ @property # type: ignore
+ def __doc__(self) -> Optional[str]: # type: ignore
+ return self._doc
+
+ @__doc__.setter
+ def __doc__(self, value: Optional[str]) -> None:
+ self._doc = value
+
+ def __set__(self, instance: object, value: Any) -> None:
self.impl.set(
instance_state(instance), instance_dict(instance), value, None
)
- def __delete__(self, instance):
+ def __delete__(self, instance: object) -> None:
self.impl.delete(instance_state(instance), instance_dict(instance))
@overload
- def __get__(
- self, instance: None, owner: Type[Any]
- ) -> InstrumentedAttribute:
+ def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
...
@overload
- def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]:
+ def __get__(self, instance: object, owner: Any) -> _T:
...
def __get__(
- self, instance: Optional[object], owner: Type[Any]
- ) -> Union[InstrumentedAttribute, Optional[_T]]:
+ self, instance: Optional[object], owner: Any
+ ) -> Union[InstrumentedAttribute[_T], _T]:
if instance is None:
return self
dict_ = instance_dict(instance)
- if self._supports_population and self.key in dict_:
- return dict_[self.key]
+ if self.impl.supports_population and self.key in dict_:
+ return dict_[self.key] # type: ignore[no-any-return]
else:
try:
state = instance_state(instance)
except AttributeError as err:
raise orm_exc.UnmappedInstanceError(instance) from err
- return self.impl.get(state, dict_)
+ return self.impl.get(state, dict_) # type: ignore[no-any-return]
-HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"])
-HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+@dataclasses.dataclass(frozen=True)
+class AdHocHasEntityNamespace:
+ # py37 compat, no slots=True on dataclass
+ __slots__ = ("entity_namespace",)
+ entity_namespace: _ExternalEntityType[Any]
+ is_mapper: ClassVar[bool] = False
+ is_aliased_class: ClassVar[bool] = False
def create_proxied_attribute(
@@ -455,7 +569,7 @@ def create_proxied_attribute(
# TODO: can move this to descriptor_props if the need for this
# function is removed from ext/hybrid.py
- class Proxy(QueryableAttribute):
+ class Proxy(QueryableAttribute[Any]):
"""Presents the :class:`.QueryableAttribute` interface as a
proxy on top of a Python descriptor / :class:`.PropComparator`
combination.
@@ -464,6 +578,10 @@ def create_proxied_attribute(
_extra_criteria = ()
+ # the attribute error catches inside of __getattr__ basically create a
+ # singularity if you try putting slots on this too
+ # __slots__ = ("descriptor", "original_property", "_comparator")
+
def __init__(
self,
class_,
@@ -480,7 +598,15 @@ def create_proxied_attribute(
self.original_property = original_property
self._comparator = comparator
self._adapt_to_entity = adapt_to_entity
- self.__doc__ = doc
+ self._doc = self.__doc__ = doc
+
+ @property
+ def _parententity(self):
+ return inspection.inspect(self.class_)
+
+ @property
+ def parent(self):
+ return inspection.inspect(self.class_)
_is_internal_proxy = True
@@ -497,17 +623,13 @@ def create_proxied_attribute(
)
@property
- def _parententity(self):
- return inspection.inspect(self.class_, raiseerr=False)
-
- @property
def _entity_namespace(self):
if hasattr(self._comparator, "_parententity"):
return self._comparator._parententity
else:
# used by hybrid attributes which try to remain
# agnostic of any ORM concepts like mappers
- return HasEntityNamespace(self.class_)
+ return AdHocHasEntityNamespace(self.class_)
@property
def property(self):
@@ -552,12 +674,22 @@ def create_proxied_attribute(
else:
return retval
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
def __getattr__(self, attribute):
"""Delegate __getattr__ to the original descriptor and/or
comparator."""
+
+ # this is unfortunately very complicated, and is easily prone
+ # to recursion overflows when implementations of related
+ # __getattr__ schemes are changed
+
+ try:
+ return util.MemoizedSlots.__getattr__(self, attribute)
+ except AttributeError:
+ pass
+
try:
return getattr(descriptor, attribute)
except AttributeError as err:
@@ -602,7 +734,7 @@ OP_BULK_REPLACE = util.symbol("BULK_REPLACE")
OP_MODIFIED = util.symbol("MODIFIED")
-class AttributeEvent:
+class AttributeEventToken:
"""A token propagated throughout the course of a chain of attribute
events.
@@ -619,7 +751,8 @@ class AttributeEvent:
event handlers, and is used to control the propagation of operations
across two mutually-dependent attributes.
- .. versionadded:: 0.9.0
+ .. versionchanged:: 2.0 Changed the name from ``AttributeEvent``
+ to ``AttributeEventToken``.
:attribute impl: The :class:`.AttributeImpl` which is the current event
initiator.
@@ -639,7 +772,7 @@ class AttributeEvent:
def __eq__(self, other):
return (
- isinstance(other, AttributeEvent)
+ isinstance(other, AttributeEventToken)
and other.impl is self.impl
and other.op == self.op
)
@@ -652,28 +785,37 @@ class AttributeEvent:
return self.impl.hasparent(state)
-Event = AttributeEvent
+AttributeEvent = AttributeEventToken # legacy
+Event = AttributeEventToken # legacy
class AttributeImpl:
"""internal implementation for instrumented attributes."""
collection: bool
+ default_accepts_scalar_loader: bool
+ uses_objects: bool
+ supports_population: bool
+ dynamic: bool
+
+ _replace_token: AttributeEventToken
+ _remove_token: AttributeEventToken
+ _append_token: AttributeEventToken
def __init__(
self,
- class_,
- key,
- callable_,
- dispatch,
- trackparent=False,
- compare_function=None,
- active_history=False,
- parent_token=None,
- load_on_unexpire=True,
- send_modified_events=True,
- accepts_scalar_loader=None,
- **kwargs,
+ class_: _ExternalEntityType[_O],
+ key: str,
+ callable_: _LoaderCallable,
+ dispatch: _Dispatch[QueryableAttribute[Any]],
+ trackparent: bool = False,
+ compare_function: Optional[Callable[..., bool]] = None,
+ active_history: bool = False,
+ parent_token: Optional[AttributeEventToken] = None,
+ load_on_unexpire: bool = True,
+ send_modified_events: bool = True,
+ accepts_scalar_loader: Optional[bool] = None,
+ **kwargs: Any,
):
r"""Construct an AttributeImpl.
@@ -743,7 +885,7 @@ class AttributeImpl:
self.dispatch._active_history = True
self.load_on_unexpire = load_on_unexpire
- self._modified_token = Event(self, OP_MODIFIED)
+ self._modified_token = AttributeEventToken(self, OP_MODIFIED)
__slots__ = (
"class_",
@@ -760,8 +902,8 @@ class AttributeImpl:
"_deferred_history",
)
- def __str__(self):
- return "%s.%s" % (self.class_.__name__, self.key)
+ def __str__(self) -> str:
+ return f"{self.class_.__name__}.{self.key}"
def _get_active_history(self):
"""Backwards compat for impl.active_history"""
@@ -773,7 +915,9 @@ class AttributeImpl:
active_history = property(_get_active_history, _set_active_history)
- def hasparent(self, state, optimistic=False):
+ def hasparent(
+ self, state: InstanceState[Any], optimistic: bool = False
+ ) -> bool:
"""Return the boolean value of a `hasparent` flag attached to
the given state.
@@ -796,7 +940,12 @@ class AttributeImpl:
state.parents.get(id(self.parent_token), optimistic) is not False
)
- def sethasparent(self, state, parent_state, value):
+ def sethasparent(
+ self,
+ state: InstanceState[Any],
+ parent_state: InstanceState[Any],
+ value: bool,
+ ) -> None:
"""Set a boolean flag on the given item corresponding to
whether or not it is attached to a parent object via the
attribute represented by this ``InstrumentedAttribute``.
@@ -839,11 +988,16 @@ class AttributeImpl:
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- passive=PASSIVE_OFF,
+ passive: PassiveFlag = PASSIVE_OFF,
) -> History:
raise NotImplementedError()
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
"""Return a list of tuples of (state, obj)
for all objects in this attribute's current state
+ history.
@@ -861,7 +1015,9 @@ class AttributeImpl:
"""
raise NotImplementedError()
- def _default_value(self, state, dict_):
+ def _default_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict
+ ) -> Any:
"""Produce an empty value for an uninitialized scalar attribute."""
assert self.key not in dict_, (
@@ -877,7 +1033,12 @@ class AttributeImpl:
return value
- def get(self, state, dict_, passive=PASSIVE_OFF):
+ def get(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Any:
"""Retrieve a value from the given object.
If a callable is assembled on this object's attribute, and
passive is False, the callable will be executed and the
@@ -917,7 +1078,9 @@ class AttributeImpl:
else:
return self._default_value(state, dict_)
- def _fire_loader_callables(self, state, key, passive):
+ def _fire_loader_callables(
+ self, state: InstanceState[Any], key: str, passive: PassiveFlag
+ ) -> Any:
if (
self.accepts_scalar_loader
and self.load_on_unexpire
@@ -932,15 +1095,36 @@ class AttributeImpl:
else:
return ATTR_EMPTY
- def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def append(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(state, dict_, value, initiator, passive=passive)
- def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def remove(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(
state, dict_, None, initiator, passive=passive, check_old=value
)
- def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def pop(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
self.set(
state,
dict_,
@@ -953,17 +1137,25 @@ class AttributeImpl:
def set(
self,
- state,
- dict_,
- value,
- initiator,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ ) -> None:
+ raise NotImplementedError()
+
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
raise NotImplementedError()
- def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
+ def get_committed_value(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Any:
"""return the unchanged value of this attribute"""
if self.key in state.committed_state:
@@ -996,10 +1188,12 @@ class ScalarAttributeImpl(AttributeImpl):
def __init__(self, *arg, **kw):
super(ScalarAttributeImpl, self).__init__(*arg, **kw)
- self._replace_token = self._append_token = Event(self, OP_REPLACE)
- self._remove_token = Event(self, OP_REMOVE)
+ self._replace_token = self._append_token = AttributeEventToken(
+ self, OP_REPLACE
+ )
+ self._remove_token = AttributeEventToken(self, OP_REMOVE)
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
else:
@@ -1042,11 +1236,11 @@ class ScalarAttributeImpl(AttributeImpl):
state: InstanceState[Any],
dict_: Dict[str, Any],
value: Any,
- initiator: Optional[Event],
+ initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
check_old: Optional[object] = None,
pop: bool = False,
- ):
+ ) -> None:
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
else:
@@ -1059,21 +1253,30 @@ class ScalarAttributeImpl(AttributeImpl):
state._modified_event(dict_, self, old)
dict_[self.key] = value
- def fire_replace_event(self, state, dict_, value, previous, initiator):
+ def fire_replace_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ previous: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.set:
value = fn(
state, value, previous, initiator or self._replace_token
)
return value
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
for fn in self.dispatch.remove:
fn(state, value, initiator or self._remove_token)
- @property
- def type(self):
- self.property.columns[0].type
-
class ScalarObjectAttributeImpl(ScalarAttributeImpl):
"""represents a scalar-holding InstrumentedAttribute,
@@ -1090,7 +1293,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
__slots__ = ()
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.dispatch._active_history:
old = self.get(
state,
@@ -1122,7 +1325,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
):
raise AttributeError("%s object does not have a value" % self)
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
if self.key in dict_:
current = dict_[self.key]
else:
@@ -1152,7 +1360,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
self, state, current, original=original
)
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
if self.key in dict_:
current = dict_[self.key]
elif passive & CALLABLES_OK:
@@ -1160,6 +1373,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
else:
return []
+ ret: _AllPendingType
+
# can't use __hash__(), can't use __eq__() here
if (
current is not None
@@ -1184,14 +1399,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
def set(
self,
- state,
- dict_,
- value,
- initiator,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ ) -> None:
"""Set a value on the given InstanceState."""
if self.dispatch._active_history:
@@ -1227,7 +1442,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
value = self.fire_replace_event(state, dict_, value, old, initiator)
dict_[self.key] = value
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
if self.trackparent and value not in (
None,
PASSIVE_NO_RESULT,
@@ -1240,7 +1461,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
state._modified_event(dict_, self, value)
- def fire_replace_event(self, state, dict_, value, previous, initiator):
+ def fire_replace_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ previous: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
if self.trackparent:
if previous is not value and previous not in (
None,
@@ -1263,7 +1491,64 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
return value
-class CollectionAttributeImpl(AttributeImpl):
+class HasCollectionAdapter:
+ __slots__ = ()
+
+ def _dispose_previous_collection(
+ self,
+ state: InstanceState[Any],
+ collection: _AdaptedCollectionProtocol,
+ adapter: CollectionAdapter,
+ fire_event: bool,
+ ) -> None:
+ raise NotImplementedError()
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ raise NotImplementedError()
+
+
+if TYPE_CHECKING:
+
+ def _is_collection_attribute_impl(
+ impl: AttributeImpl,
+ ) -> TypeGuard[CollectionAttributeImpl]:
+ ...
+
+else:
+ _is_collection_attribute_impl = operator.attrgetter("collection")
+
+
+class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
"""A collection-holding attribute that instruments changes in membership.
Only handles collections of instrumented objects.
@@ -1275,12 +1560,14 @@ class CollectionAttributeImpl(AttributeImpl):
"""
- default_accepts_scalar_loader = False
uses_objects = True
- supports_population = True
collection = True
+ default_accepts_scalar_loader = False
+ supports_population = True
dynamic = False
+ _bulk_replace_token: AttributeEventToken
+
__slots__ = (
"copy",
"collection_factory",
@@ -1316,9 +1603,9 @@ class CollectionAttributeImpl(AttributeImpl):
copy_function = self.__copy
self.copy = copy_function
self.collection_factory = typecallable
- self._append_token = Event(self, OP_APPEND)
- self._remove_token = Event(self, OP_REMOVE)
- self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
+ self._append_token = AttributeEventToken(self, OP_APPEND)
+ self._remove_token = AttributeEventToken(self, OP_REMOVE)
+ self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE)
self._duck_typed_as = util.duck_type_collection(
self.collection_factory()
)
@@ -1336,14 +1623,24 @@ class CollectionAttributeImpl(AttributeImpl):
def __copy(self, item):
return [y for y in collections.collection_adapter(item)]
- def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> History:
current = self.get(state, dict_, passive=passive)
if current is PASSIVE_NO_RESULT:
return HISTORY_BLANK
else:
return History.from_collection(self, state, current)
- def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ def get_all_pending(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+ ) -> _AllPendingType:
# NOTE: passive is ignored here at the moment
if self.key not in dict_:
@@ -1383,7 +1680,13 @@ class CollectionAttributeImpl(AttributeImpl):
return [(instance_state(o), o) for o in current]
- def fire_append_event(self, state, dict_, value, initiator):
+ def fire_append_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.append:
value = fn(state, value, initiator or self._append_token)
@@ -1394,13 +1697,24 @@ class CollectionAttributeImpl(AttributeImpl):
return value
- def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+ def fire_append_wo_mutation_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: _T,
+ initiator: Optional[AttributeEventToken],
+ ) -> _T:
for fn in self.dispatch.append_wo_mutation:
value = fn(state, value, initiator or self._append_token)
return value
- def fire_pre_remove_event(self, state, dict_, initiator):
+ def fire_pre_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
"""A special event used for pop() operations.
The "remove" event needs to have the item to be removed passed to
@@ -1411,7 +1725,13 @@ class CollectionAttributeImpl(AttributeImpl):
"""
state._modified_event(dict_, self, NO_VALUE, True)
- def fire_remove_event(self, state, dict_, value, initiator):
+ def fire_remove_event(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ ) -> None:
if self.trackparent and value is not None:
self.sethasparent(instance_state(value), state, False)
@@ -1420,7 +1740,7 @@ class CollectionAttributeImpl(AttributeImpl):
state._modified_event(dict_, self, NO_VALUE, True)
- def delete(self, state, dict_):
+ def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
if self.key not in dict_:
return
@@ -1433,7 +1753,9 @@ class CollectionAttributeImpl(AttributeImpl):
# del is a no-op if collection not present.
del dict_[self.key]
- def _default_value(self, state, dict_):
+ def _default_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict
+ ) -> _AdaptedCollectionProtocol:
"""Produce an empty collection for an un-initialized attribute"""
assert self.key not in dict_, (
@@ -1448,7 +1770,9 @@ class CollectionAttributeImpl(AttributeImpl):
adapter._set_empty(user_data)
return user_data
- def _initialize_collection(self, state):
+ def _initialize_collection(
+ self, state: InstanceState[Any]
+ ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]:
adapter, collection = state.manager.initialize_collection(
self.key, state, self.collection_factory
@@ -1458,7 +1782,14 @@ class CollectionAttributeImpl(AttributeImpl):
return adapter, collection
- def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def append(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
collection = self.get_collection(state, dict_, passive=passive)
if collection is PASSIVE_NO_RESULT:
value = self.fire_append_event(state, dict_, value, initiator)
@@ -1467,9 +1798,18 @@ class CollectionAttributeImpl(AttributeImpl):
), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).append(value)
else:
+ if TYPE_CHECKING:
+ assert isinstance(collection, CollectionAdapter)
collection.append_with_event(value, initiator)
- def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def remove(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
collection = self.get_collection(state, state.dict, passive=passive)
if collection is PASSIVE_NO_RESULT:
self.fire_remove_event(state, dict_, value, initiator)
@@ -1478,9 +1818,18 @@ class CollectionAttributeImpl(AttributeImpl):
), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).remove(value)
else:
+ if TYPE_CHECKING:
+ assert isinstance(collection, CollectionAdapter)
collection.remove_with_event(value, initiator)
- def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ def pop(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken],
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> None:
try:
# TODO: better solution here would be to add
# a "popper" role to collections.py to complement
@@ -1491,15 +1840,15 @@ class CollectionAttributeImpl(AttributeImpl):
def set(
self,
- state,
- dict_,
- value,
- initiator=None,
- passive=PASSIVE_OFF,
- check_old=None,
- pop=False,
- _adapt=True,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
iterable = orig_iterable = value
# pulling a new collection first so that an adaptation exception does
@@ -1518,7 +1867,7 @@ class CollectionAttributeImpl(AttributeImpl):
and "None"
or iterable.__class__.__name__
)
- wanted = self._duck_typed_as.__name__
+ wanted = self._duck_typed_as.__name__ # type: ignore
raise TypeError(
"Incompatible collection type: %s is not %s-like"
% (given, wanted)
@@ -1560,8 +1909,12 @@ class CollectionAttributeImpl(AttributeImpl):
self._dispose_previous_collection(state, old, old_collection, True)
def _dispose_previous_collection(
- self, state, collection, adapter, fire_event
- ):
+ self,
+ state: InstanceState[Any],
+ collection: _AdaptedCollectionProtocol,
+ adapter: CollectionAdapter,
+ fire_event: bool,
+ ) -> None:
del collection._sa_adapter
# discarding old collection make sure it is not referenced in empty
@@ -1570,11 +1923,15 @@ class CollectionAttributeImpl(AttributeImpl):
if fire_event:
self.dispatch.dispose_collection(state, collection, adapter)
- def _invalidate_collection(self, collection: Collection) -> None:
+ def _invalidate_collection(
+ self, collection: _AdaptedCollectionProtocol
+ ) -> None:
adapter = getattr(collection, "_sa_adapter")
adapter.invalidated = True
- def set_committed_value(self, state, dict_, value):
+ def set_committed_value(
+ self, state: InstanceState[Any], dict_: _InstanceDict, value: Any
+ ) -> _AdaptedCollectionProtocol:
"""Set an attribute value on the given instance and 'commit' it."""
collection, user_data = self._initialize_collection(state)
@@ -1601,9 +1958,37 @@ class CollectionAttributeImpl(AttributeImpl):
return user_data
+ @overload
def get_collection(
- self, state, dict_, user_data=None, passive=PASSIVE_OFF
- ):
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
"""Retrieve the CollectionAdapter associated with the given state.
if user_data is None, retrieves it from the state using normal
@@ -1612,14 +1997,18 @@ class CollectionAttributeImpl(AttributeImpl):
"""
if user_data is None:
- user_data = self.get(state, dict_, passive=passive)
- if user_data is PASSIVE_NO_RESULT:
- return user_data
+ fetch_user_data = self.get(state, dict_, passive=passive)
+ if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT:
+ return fetch_user_data
+ else:
+ user_data = cast("_AdaptedCollectionProtocol", fetch_user_data)
return user_data._sa_adapter
-def backref_listeners(attribute, key, uselist):
+def backref_listeners(
+ attribute: QueryableAttribute[Any], key: str, uselist: bool
+) -> None:
"""Apply listeners to synchronize a two-way relationship."""
# use easily recognizable names for stack traces.
@@ -1695,7 +2084,7 @@ def backref_listeners(attribute, key, uselist):
check_append_token = child_impl._append_token
check_bulk_replace_token = (
child_impl._bulk_replace_token
- if child_impl.collection
+ if _is_collection_attribute_impl(child_impl)
else None
)
@@ -1728,7 +2117,9 @@ def backref_listeners(attribute, key, uselist):
# tokens to test for a recursive loop.
check_append_token = child_impl._append_token
check_bulk_replace_token = (
- child_impl._bulk_replace_token if child_impl.collection else None
+ child_impl._bulk_replace_token
+ if _is_collection_attribute_impl(child_impl)
+ else None
)
if (
@@ -1756,6 +2147,8 @@ def backref_listeners(attribute, key, uselist):
)
child_impl = child_state.manager[key].impl
+ check_replace_token: Optional[AttributeEventToken]
+
# tokens to test for a recursive loop.
if not child_impl.collection and not child_impl.dynamic:
check_remove_token = child_impl._remove_token
@@ -1765,7 +2158,7 @@ def backref_listeners(attribute, key, uselist):
check_remove_token = child_impl._remove_token
check_replace_token = (
child_impl._bulk_replace_token
- if child_impl.collection
+ if _is_collection_attribute_impl(child_impl)
else None
)
check_for_dupes_on_remove = False
@@ -1848,10 +2241,10 @@ class History(NamedTuple):
unchanged: Union[Tuple[()], List[Any]]
deleted: Union[Tuple[()], List[Any]]
- def __bool__(self):
+ def __bool__(self) -> bool:
return self != HISTORY_BLANK
- def empty(self):
+ def empty(self) -> bool:
"""Return True if this :class:`.History` has no changes
and no existing, unchanged state.
@@ -1859,29 +2252,29 @@ class History(NamedTuple):
return not bool((self.added or self.deleted) or self.unchanged)
- def sum(self):
+ def sum(self) -> Sequence[Any]:
"""Return a collection of added + unchanged + deleted."""
return (
(self.added or []) + (self.unchanged or []) + (self.deleted or [])
)
- def non_deleted(self):
+ def non_deleted(self) -> Sequence[Any]:
"""Return a collection of added + unchanged."""
return (self.added or []) + (self.unchanged or [])
- def non_added(self):
+ def non_added(self) -> Sequence[Any]:
"""Return a collection of unchanged + deleted."""
return (self.unchanged or []) + (self.deleted or [])
- def has_changes(self):
+ def has_changes(self) -> bool:
"""Return True if this :class:`.History` has changes."""
return bool(self.added or self.deleted)
- def as_state(self):
+ def as_state(self) -> History:
return History(
[
(c is not None) and instance_state(c) or None
@@ -1898,9 +2291,16 @@ class History(NamedTuple):
)
@classmethod
- def from_scalar_attribute(cls, attribute, state, current):
+ def from_scalar_attribute(
+ cls,
+ attribute: ScalarAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ ) -> History:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
+ deleted: Union[Tuple[()], List[Any]]
+
if original is _NO_HISTORY:
if current is NO_VALUE:
return cls((), (), ())
@@ -1933,8 +2333,14 @@ class History(NamedTuple):
@classmethod
def from_object_attribute(
- cls, attribute, state, current, original=_NO_HISTORY
- ):
+ cls,
+ attribute: ScalarObjectAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ original: Any = _NO_HISTORY,
+ ) -> History:
+ deleted: Union[Tuple[()], List[Any]]
+
if original is _NO_HISTORY:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
@@ -1965,7 +2371,12 @@ class History(NamedTuple):
return cls([current], (), deleted)
@classmethod
- def from_collection(cls, attribute, state, current):
+ def from_collection(
+ cls,
+ attribute: CollectionAttributeImpl,
+ state: InstanceState[Any],
+ current: Any,
+ ) -> History:
original = state.committed_state.get(attribute.key, _NO_HISTORY)
if current is NO_VALUE:
return cls((), (), ())
@@ -1999,7 +2410,9 @@ class History(NamedTuple):
HISTORY_BLANK = History((), (), ())
-def get_history(obj, key, passive=PASSIVE_OFF):
+def get_history(
+ obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
"""Return a :class:`.History` record for the given object
and attribute key.
@@ -2037,36 +2450,47 @@ def get_history(obj, key, passive=PASSIVE_OFF):
return get_state_history(instance_state(obj), key, passive)
-def get_state_history(state, key, passive=PASSIVE_OFF):
+def get_state_history(
+ state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
return state.get_history(key, passive)
-def has_parent(cls, obj, key, optimistic=False):
+def has_parent(
+ cls: Type[_O], obj: _O, key: str, optimistic: bool = False
+) -> bool:
"""TODO"""
manager = manager_of_class(cls)
state = instance_state(obj)
return manager.has_parent(state, key, optimistic)
-def register_attribute(class_, key, **kw):
- comparator = kw.pop("comparator", None)
- parententity = kw.pop("parententity", None)
- doc = kw.pop("doc", None)
- desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
+def register_attribute(
+ class_: Type[_O],
+ key: str,
+ *,
+ comparator: interfaces.PropComparator[_T],
+ parententity: _InternalEntityType[_O],
+ doc: Optional[str] = None,
+ **kw: Any,
+) -> InstrumentedAttribute[_T]:
+ desc = register_descriptor(
+ class_, key, comparator=comparator, parententity=parententity, doc=doc
+ )
register_attribute_impl(class_, key, **kw)
return desc
def register_attribute_impl(
- class_,
- key,
- uselist=False,
- callable_=None,
- useobject=False,
- impl_class=None,
- backref=None,
- **kw,
-):
+ class_: Type[_O],
+ key: str,
+ uselist: bool = False,
+ callable_: Optional[_LoaderCallable] = None,
+ useobject: bool = False,
+ impl_class: Optional[Type[AttributeImpl]] = None,
+ backref: Optional[str] = None,
+ **kw: Any,
+) -> InstrumentedAttribute[Any]:
manager = manager_of_class(class_)
if uselist:
@@ -2077,10 +2501,18 @@ def register_attribute_impl(
else:
typecallable = kw.pop("typecallable", None)
- dispatch = manager[key].dispatch
+ dispatch = cast(
+ "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch
+ ) # noqa: E501
+
+ impl: AttributeImpl
if impl_class:
- impl = impl_class(class_, key, typecallable, dispatch, **kw)
+ # TODO: this appears to be the DynamicAttributeImpl constructor
+ # which is hardcoded
+ impl = cast("Type[DynamicAttributeImpl]", impl_class)(
+ class_, key, typecallable, dispatch, **kw
+ )
elif uselist:
impl = CollectionAttributeImpl(
class_, key, callable_, dispatch, typecallable=typecallable, **kw
@@ -2102,8 +2534,13 @@ def register_attribute_impl(
def register_descriptor(
- class_, key, comparator=None, parententity=None, doc=None
-):
+ class_: Type[Any],
+ key: str,
+ *,
+ comparator: interfaces.PropComparator[_T],
+ parententity: _InternalEntityType[Any],
+ doc: Optional[str] = None,
+) -> InstrumentedAttribute[_T]:
manager = manager_of_class(class_)
descriptor = InstrumentedAttribute(
@@ -2116,11 +2553,11 @@ def register_descriptor(
return descriptor
-def unregister_attribute(class_, key):
+def unregister_attribute(class_: Type[Any], key: str) -> None:
manager_of_class(class_).uninstrument_attribute(key)
-def init_collection(obj, key):
+def init_collection(obj: object, key: str) -> CollectionAdapter:
"""Initialize a collection attribute and return the collection adapter.
This function is used to provide direct access to collection internals
@@ -2143,7 +2580,9 @@ def init_collection(obj, key):
return init_state_collection(state, dict_, key)
-def init_state_collection(state, dict_, key):
+def init_state_collection(
+ state: InstanceState[Any], dict_: _InstanceDict, key: str
+) -> CollectionAdapter:
"""Initialize a collection attribute and return the collection adapter.
Discards any existing collection which may be there.
@@ -2151,6 +2590,9 @@ def init_state_collection(state, dict_, key):
"""
attr = state.manager[key].impl
+ if TYPE_CHECKING:
+ assert isinstance(attr, HasCollectionAdapter)
+
old = dict_.pop(key, None) # discard old collection
if old is not None:
old_collection = old._sa_adapter
@@ -2182,7 +2624,12 @@ def set_committed_value(instance, key, value):
state.manager[key].impl.set_committed_value(state, dict_, value)
-def set_attribute(instance, key, value, initiator=None):
+def set_attribute(
+ instance: object,
+ key: str,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+) -> None:
"""Set the value of an attribute, firing history events.
This function may be used regardless of instrumentation
@@ -2211,7 +2658,7 @@ def set_attribute(instance, key, value, initiator=None):
state.manager[key].impl.set(state, dict_, value, initiator)
-def get_attribute(instance, key):
+def get_attribute(instance: object, key: str) -> Any:
"""Get the value of an attribute, firing any callables required.
This function may be used regardless of instrumentation
@@ -2225,7 +2672,7 @@ def get_attribute(instance, key):
return state.manager[key].impl.get(state, dict_)
-def del_attribute(instance, key):
+def del_attribute(instance: object, key: str) -> None:
"""Delete the value of an attribute, firing history events.
This function may be used regardless of instrumentation
@@ -2239,7 +2686,7 @@ def del_attribute(instance, key):
state.manager[key].impl.delete(state, dict_)
-def flag_modified(instance, key):
+def flag_modified(instance: object, key: str) -> None:
"""Mark an attribute on an instance as 'modified'.
This sets the 'modified' flag on the instance and
@@ -2262,7 +2709,7 @@ def flag_modified(instance, key):
state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
-def flag_dirty(instance):
+def flag_dirty(instance: object) -> None:
"""Mark an instance as 'dirty' without any specific attribute mentioned.
This is a special operation that will allow the object to travel through
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 367a5332d..0ace9b1cb 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -38,14 +38,15 @@ from ..util.typing import Literal
from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from ._typing import _ExternalEntityType
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
from .instrumentation import ClassManager
from .mapper import Mapper
from .state import InstanceState
+ from .util import AliasedClass
from ..sql._typing import _InfoType
-
_T = TypeVar("_T", bound=Any)
_O = TypeVar("_O", bound=object)
@@ -267,10 +268,22 @@ def _assertions(
if TYPE_CHECKING:
- def manager_of_class(cls: Type[Any]) -> ClassManager:
+ def manager_of_class(cls: Type[_O]) -> ClassManager[_O]:
+ ...
+
+ @overload
+ def opt_manager_of_class(cls: AliasedClass[Any]) -> None:
...
- def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]:
+ @overload
+ def opt_manager_of_class(
+ cls: _ExternalEntityType[_O],
+ ) -> Optional[ClassManager[_O]]:
+ ...
+
+ def opt_manager_of_class(
+ cls: _ExternalEntityType[_O],
+ ) -> Optional[ClassManager[_O]]:
...
def instance_state(instance: _O) -> InstanceState[_O]:
@@ -719,7 +732,7 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
...
def __get__(
- self, instance: object, owner: Any
+ self, instance: Optional[object], owner: Any
) -> Union[InstrumentedAttribute[_T], _T]:
...
@@ -729,10 +742,10 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
def __set__(
self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
- ):
+ ) -> None:
...
- def __delete__(self, instance: Any):
+ def __delete__(self, instance: Any) -> None:
...
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 717f1d0d6..da0da0fcf 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Support for collections of mapped entities.
@@ -109,17 +109,34 @@ import operator
import threading
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from .. import exc as sa_exc
from .. import util
from ..util.compat import inspect_getfullargspec
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .attributes import CollectionAttributeImpl
from .mapped_collection import attribute_mapped_collection
from .mapped_collection import column_mapped_collection
from .mapped_collection import mapped_collection
from .mapped_collection import MappedCollection # noqa: F401
+ from .state import InstanceState
+
__all__ = [
"collection",
@@ -132,6 +149,28 @@ __all__ = [
__instrumentation_mutex = threading.Lock()
+_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"]
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+_COL = TypeVar("_COL", bound="Collection[Any]")
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
+
+
+class _CollectionConverterProtocol(Protocol):
+ def __call__(self, collection: _COL) -> _COL:
+ ...
+
+
+class _AdaptedCollectionProtocol(Protocol):
+ _sa_adapter: CollectionAdapter
+ _sa_appender: Callable[..., Any]
+ _sa_remover: Callable[..., Any]
+ _sa_iterator: Callable[..., Iterable[Any]]
+ _sa_converter: _CollectionConverterProtocol
+
+
class collection:
"""Decorators for entity collection classes.
@@ -396,8 +435,13 @@ class collection:
return decorator
-collection_adapter = operator.attrgetter("_sa_adapter")
-"""Fetch the :class:`.CollectionAdapter` for a collection."""
+if TYPE_CHECKING:
+
+ def collection_adapter(collection: Collection[Any]) -> CollectionAdapter:
+ """Fetch the :class:`.CollectionAdapter` for a collection."""
+
+else:
+ collection_adapter = operator.attrgetter("_sa_adapter")
class CollectionAdapter:
@@ -423,10 +467,33 @@ class CollectionAdapter:
"empty",
)
- def __init__(self, attr, owner_state, data):
+ attr: CollectionAttributeImpl
+ _key: str
+
+ # this is actually a weakref; see note in constructor
+ _data: Callable[..., _AdaptedCollectionProtocol]
+
+ owner_state: InstanceState[Any]
+ _converter: _CollectionConverterProtocol
+ invalidated: bool
+ empty: bool
+
+ def __init__(
+ self,
+ attr: CollectionAttributeImpl,
+ owner_state: InstanceState[Any],
+ data: _AdaptedCollectionProtocol,
+ ):
self.attr = attr
self._key = attr.key
- self._data = weakref.ref(data)
+
+ # this weakref stays referenced throughout the lifespan of
+ # CollectionAdapter. so while the weakref can return None, this
+ # is realistically only during garbage collection of this object, so
+ # we type this as a callable that returns _AdaptedCollectionProtocol
+ # in all cases.
+ self._data = weakref.ref(data) # type: ignore
+
self.owner_state = owner_state
data._sa_adapter = self
self._converter = data._sa_converter
@@ -437,7 +504,7 @@ class CollectionAdapter:
util.warn("This collection has been invalidated.")
@property
- def data(self):
+ def data(self) -> _AdaptedCollectionProtocol:
"The entity collection being adapted."
return self._data()
@@ -634,7 +701,10 @@ class CollectionAdapter:
def __setstate__(self, d):
self._key = d["key"]
self.owner_state = d["owner_state"]
- self._data = weakref.ref(d["data"])
+
+ # see note in constructor regarding this type: ignore
+ self._data = weakref.ref(d["data"]) # type: ignore
+
self._converter = d["data"]._sa_converter
d["data"]._sa_adapter = self
self.invalidated = d["invalidated"]
@@ -682,7 +752,9 @@ def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
existing_adapter.fire_remove_event(member, initiator=initiator)
-def prepare_instrumentation(factory):
+def prepare_instrumentation(
+ factory: Union[Type[Collection[Any]], _CollectionFactoryType],
+) -> _CollectionFactoryType:
"""Prepare a callable for future use as a collection class factory.
Given a collection class factory (either a type or no-arg callable),
@@ -693,18 +765,30 @@ def prepare_instrumentation(factory):
into the run-time behavior of collection_class=InstrumentedList.
"""
+
+ impl_factory: _CollectionFactoryType
+
# Convert a builtin to 'Instrumented*'
if factory in __canned_instrumentation:
- factory = __canned_instrumentation[factory]
+ impl_factory = __canned_instrumentation[factory]
+ else:
+ impl_factory = cast(_CollectionFactoryType, factory)
+
+ cls: Union[_CollectionFactoryType, Type[Collection[Any]]]
# Create a specimen
- cls = type(factory())
+ cls = type(impl_factory())
# Did factory callable return a builtin?
if cls in __canned_instrumentation:
- # Wrap it so that it returns our 'Instrumented*'
- factory = __converting_factory(cls, factory)
- cls = factory()
+
+ # if so, just convert.
+ # in previous major releases, this codepath wasn't working and was
+ # not covered by tests. prior to that it supplied a "wrapper"
+ # function that would return the class, though the rationale for this
+ # case is not known
+ impl_factory = __canned_instrumentation[cls]
+ cls = type(impl_factory())
# Instrument the class if needed.
if __instrumentation_mutex.acquire():
@@ -714,26 +798,7 @@ def prepare_instrumentation(factory):
finally:
__instrumentation_mutex.release()
- return factory
-
-
-def __converting_factory(specimen_cls, original_factory):
- """Return a wrapper that converts a "canned" collection like
- set, dict, list into the Instrumented* version.
-
- """
-
- instrumented_cls = __canned_instrumentation[specimen_cls]
-
- def wrapper():
- collection = original_factory()
- return instrumented_cls(collection)
-
- # often flawed but better than nothing
- wrapper.__name__ = "%sWrapper" % original_factory.__name__
- wrapper.__doc__ = original_factory.__doc__
-
- return wrapper
+ return impl_factory
def _instrument_class(cls):
@@ -763,8 +828,8 @@ def _locate_roles_and_methods(cls):
"""
- roles = {}
- methods = {}
+ roles: Dict[str, str] = {}
+ methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {}
for supercls in cls.__mro__:
for name, method in vars(supercls).items():
@@ -784,7 +849,9 @@ def _locate_roles_and_methods(cls):
# transfer instrumentation requests from decorated function
# to the combined queue
- before, after = None, None
+ before: Optional[Tuple[str, int]] = None
+ after: Optional[str] = None
+
if hasattr(method, "_sa_instrument_before"):
op, argument = method._sa_instrument_before
assert op in ("fire_append_event", "fire_remove_event")
@@ -809,6 +876,7 @@ def _setup_canned_roles(cls, roles, methods):
"""
collection_type = util.duck_type_collection(cls)
if collection_type in __interfaces:
+ assert collection_type is not None
canned_roles, decorators = __interfaces[collection_type]
for role, name in canned_roles.items():
roles.setdefault(role, name)
@@ -934,9 +1002,9 @@ def _instrument_membership_mutator(method, before, argument, after):
getattr(executor, after)(res, initiator)
return res
- wrapper._sa_instrumented = True
+ wrapper._sa_instrumented = True # type: ignore[attr-defined]
if hasattr(method, "_sa_instrument_role"):
- wrapper._sa_instrument_role = method._sa_instrument_role
+ wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501
wrapper.__name__ = method.__name__
wrapper.__doc__ = method.__doc__
return wrapper
@@ -990,7 +1058,7 @@ def __before_pop(collection, _sa_initiator=None):
executor.fire_pre_remove_event(_sa_initiator)
-def _list_decorators():
+def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any list-like class."""
def _tidy(fn):
@@ -1131,7 +1199,7 @@ def _list_decorators():
return l
-def _dict_decorators():
+def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any dict-like mapping class."""
def _tidy(fn):
@@ -1255,7 +1323,7 @@ def _set_binops_check_loose(self: Any, obj: Any) -> bool:
)
-def _set_decorators():
+def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any set-like class."""
def _tidy(fn):
@@ -1420,36 +1488,52 @@ def _set_decorators():
return l
-class InstrumentedList(list):
+class InstrumentedList(List[_T]):
"""An instrumented version of the built-in list."""
-class InstrumentedSet(set):
+class InstrumentedSet(Set[_T]):
"""An instrumented version of the built-in set."""
-class InstrumentedDict(dict):
+class InstrumentedDict(Dict[_KT, _VT]):
"""An instrumented version of the built-in dict."""
-__canned_instrumentation = {
- list: InstrumentedList,
- set: InstrumentedSet,
- dict: InstrumentedDict,
-}
-
-__interfaces = {
- list: (
- {"appender": "append", "remover": "remove", "iterator": "__iter__"},
- _list_decorators(),
- ),
- set: (
- {"appender": "add", "remover": "remove", "iterator": "__iter__"},
- _set_decorators(),
- ),
- # decorators are required for dicts and object collections.
- dict: ({"iterator": "values"}, _dict_decorators()),
-}
+__canned_instrumentation: util.immutabledict[
+ Any, _CollectionFactoryType
+] = util.immutabledict(
+ {
+ list: InstrumentedList,
+ set: InstrumentedSet,
+ dict: InstrumentedDict,
+ }
+)
+
+__interfaces: util.immutabledict[
+ Any,
+ Tuple[
+ Dict[str, str],
+ Dict[str, Callable[..., Any]],
+ ],
+] = util.immutabledict(
+ {
+ list: (
+ {
+ "appender": "append",
+ "remover": "remove",
+ "iterator": "__iter__",
+ },
+ _list_decorators(),
+ ),
+ set: (
+ {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+ _set_decorators(),
+ ),
+ # decorators are required for dicts and object collections.
+ dict: ({"iterator": "values"}, _dict_decorators()),
+ }
+)
def __go(lcls):
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 63a37d0da..1b4f573b5 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -64,7 +64,9 @@ class DynaLoader(strategies.AbstractRelationshipLoader):
)
-class DynamicAttributeImpl(attributes.AttributeImpl):
+class DynamicAttributeImpl(
+ attributes.HasCollectionAdapter, attributes.AttributeImpl
+):
uses_objects = True
default_accepts_scalar_loader = False
supports_population = False
@@ -120,11 +122,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
@util.memoized_property
def _append_token(self):
- return attributes.Event(self, attributes.OP_APPEND)
+ return attributes.AttributeEventToken(self, attributes.OP_APPEND)
@util.memoized_property
def _remove_token(self):
- return attributes.Event(self, attributes.OP_REMOVE)
+ return attributes.AttributeEventToken(self, attributes.OP_REMOVE)
def fire_append_event(
self, state, dict_, value, initiator, collection_history=None
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 356958562..85b85215e 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Defines SQLAlchemy's system of class instrumentation.
@@ -35,14 +35,19 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
from typing import Dict
from typing import Generic
+from typing import Iterable
+from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
import weakref
from . import base
@@ -51,15 +56,21 @@ from . import exc
from . import interfaces
from . import state
from ._typing import _O
+from .attributes import _is_collection_attribute_impl
from .. import util
from ..event import EventTarget
from ..util import HasMemoized
+from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
from ._typing import _RegistryType
+ from .attributes import AttributeImpl
from .attributes import InstrumentedAttribute
+ from .collections import _AdaptedCollectionProtocol
+ from .collections import _CollectionFactoryType
from .decl_base import _MapperConfig
+ from .events import InstanceEvents
from .mapper import Mapper
from .state import InstanceState
from ..event import dispatcher
@@ -74,7 +85,7 @@ class _ExpiredAttributeLoaderProto(Protocol):
state: state.InstanceState[Any],
toload: Set[str],
passive: base.PassiveFlag,
- ):
+ ) -> None:
...
@@ -91,7 +102,7 @@ class ClassManager(
):
"""Tracks state information at the class level."""
- dispatch: dispatcher[ClassManager]
+ dispatch: dispatcher[ClassManager[_O]]
MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
STATE_ATTR = base.DEFAULT_STATE_ATTR
@@ -108,8 +119,9 @@ class ClassManager(
declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
registry: Optional[_RegistryType] = None
- @property
- @util.deprecated(
+ _bases: List[ClassManager[Any]]
+
+ @util.deprecated_property(
"1.4",
message="The ClassManager.deferred_scalar_loader attribute is now "
"named expired_attribute_loader",
@@ -117,7 +129,7 @@ class ClassManager(
def deferred_scalar_loader(self):
return self.expired_attribute_loader
- @deferred_scalar_loader.setter
+ @deferred_scalar_loader.setter # type: ignore[no-redef]
@util.deprecated(
"1.4",
message="The ClassManager.deferred_scalar_loader attribute is now "
@@ -138,18 +150,23 @@ class ClassManager(
self._bases = [
mgr
- for mgr in [
- opt_manager_of_class(base)
- for base in self.class_.__bases__
- if isinstance(base, type)
- ]
+ for mgr in cast(
+ "List[Optional[ClassManager[Any]]]",
+ [
+ opt_manager_of_class(base)
+ for base in self.class_.__bases__
+ if isinstance(base, type)
+ ],
+ )
if mgr is not None
]
for base_ in self._bases:
self.update(base_)
- self.dispatch._events._new_classmanager_instance(class_, self)
+ cast(
+ "InstanceEvents", self.dispatch._events
+ )._new_classmanager_instance(class_, self)
for basecls in class_.__mro__:
mgr = opt_manager_of_class(basecls)
@@ -263,7 +280,7 @@ class ClassManager(
"""
- found = {}
+ found: Dict[str, Any] = {}
# constraints:
# 1. yield keys in cls.__dict__ order
@@ -303,7 +320,7 @@ class ClassManager(
return key in self and self[key].impl is not None
- def _subclass_manager(self, cls):
+ def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]:
"""Create a new ClassManager for a subclass of this ClassManager's
class.
@@ -321,7 +338,7 @@ class ClassManager(
self.install_member("__init__", self.new_init)
@util.memoized_property
- def _state_constructor(self):
+ def _state_constructor(self) -> Type[state.InstanceState[_O]]:
self.dispatch.first_init(self, self.class_)
return state.InstanceState
@@ -393,13 +410,15 @@ class ClassManager(
if manager:
manager.uninstrument_attribute(key, True)
- def unregister(self):
+ def unregister(self) -> None:
"""remove all instrumentation established by this ClassManager."""
for key in list(self.originals):
self.uninstall_member(key)
- self.mapper = self.dispatch = self.new_init = None
+ self.mapper = None # type: ignore
+ self.dispatch = None # type: ignore
+ self.new_init = None
self.info.clear()
for key in list(self):
@@ -409,7 +428,9 @@ class ClassManager(
if self.MANAGER_ATTR in self.class_.__dict__:
delattr(self.class_, self.MANAGER_ATTR)
- def install_descriptor(self, key, inst):
+ def install_descriptor(
+ self, key: str, inst: InstrumentedAttribute[Any]
+ ) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
"%r: requested attribute name conflicts with "
@@ -417,10 +438,10 @@ class ClassManager(
)
setattr(self.class_, key, inst)
- def uninstall_descriptor(self, key):
+ def uninstall_descriptor(self, key: str) -> None:
delattr(self.class_, key)
- def install_member(self, key, implementation):
+ def install_member(self, key: str, implementation: Any) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
"%r: requested attribute name conflicts with "
@@ -429,34 +450,41 @@ class ClassManager(
self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
setattr(self.class_, key, implementation)
- def uninstall_member(self, key):
+ def uninstall_member(self, key: str) -> None:
original = self.originals.pop(key, None)
if original is not DEL_ATTR:
setattr(self.class_, key, original)
else:
delattr(self.class_, key)
- def instrument_collection_class(self, key, collection_class):
+ def instrument_collection_class(
+ self, key: str, collection_class: Type[Collection[Any]]
+ ) -> _CollectionFactoryType:
return collections.prepare_instrumentation(collection_class)
- def initialize_collection(self, key, state, factory):
+ def initialize_collection(
+ self,
+ key: str,
+ state: InstanceState[_O],
+ factory: _CollectionFactoryType,
+ ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]:
user_data = factory()
- adapter = collections.CollectionAdapter(
- self.get_impl(key), state, user_data
- )
+ impl = self.get_impl(key)
+ assert _is_collection_attribute_impl(impl)
+ adapter = collections.CollectionAdapter(impl, state, user_data)
return adapter, user_data
- def is_instrumented(self, key, search=False):
+ def is_instrumented(self, key: str, search: bool = False) -> bool:
if search:
return key in self
else:
return key in self.local_attrs
- def get_impl(self, key):
+ def get_impl(self, key: str) -> AttributeImpl:
return self[key].impl
@property
- def attributes(self):
+ def attributes(self) -> Iterable[Any]:
return iter(self.values())
# InstanceState management
@@ -466,22 +494,26 @@ class ClassManager(
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
- return instance
+ return instance # type: ignore[no-any-return]
- def setup_instance(self, instance, state=None):
+ def setup_instance(
+ self, instance: _O, state: Optional[InstanceState[_O]] = None
+ ) -> None:
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
- def teardown_instance(self, instance):
+ def teardown_instance(self, instance: _O) -> None:
delattr(instance, self.STATE_ATTR)
def _serialize(
- self, state: state.InstanceState, state_dict: Dict[str, Any]
+ self, state: InstanceState[_O], state_dict: Dict[str, Any]
) -> _SerializeManager:
return _SerializeManager(state, state_dict)
- def _new_state_if_none(self, instance):
+ def _new_state_if_none(
+ self, instance: _O
+ ) -> Union[Literal[False], InstanceState[_O]]:
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
@@ -503,20 +535,20 @@ class ClassManager(
self._state_setter(instance, state)
return state
- def has_state(self, instance):
+ def has_state(self, instance: _O) -> bool:
return hasattr(instance, self.STATE_ATTR)
- def has_parent(self, state, key, optimistic=False):
+ def has_parent(
+ self, state: InstanceState[_O], key: str, optimistic: bool = False
+ ) -> bool:
"""TODO"""
return self.get_impl(key).hasparent(state, optimistic=optimistic)
- def __bool__(self):
+ def __bool__(self) -> bool:
"""All ClassManagers are non-zero regardless of attribute state."""
return True
- __nonzero__ = __bool__
-
- def __repr__(self):
+ def __repr__(self) -> str:
return "<%s of %r at %x>" % (
self.__class__.__name__,
self.class_,
@@ -558,9 +590,11 @@ class _SerializeManager:
manager.dispatch.unpickle(state, state_dict)
-class InstrumentationFactory:
+class InstrumentationFactory(EventTarget):
"""Factory for new ClassManager instances."""
+ dispatch: dispatcher[InstrumentationFactory]
+
def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
assert class_ is not None
assert opt_manager_of_class(class_) is None
@@ -589,11 +623,10 @@ class InstrumentationFactory:
def _check_conflicts(
self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
- ):
+ ) -> None:
"""Overridden by a subclass to test for conflicting factories."""
- return
- def unregister(self, class_):
+ def unregister(self, class_: Type[_O]) -> None:
manager = manager_of_class(class_)
manager.unregister()
self.dispatch.class_uninstrument(class_)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 3e21b0102..3d093d367 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -60,6 +60,7 @@ from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
+from ..util.typing import DescriptorReference
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
@@ -68,7 +69,6 @@ if typing.TYPE_CHECKING:
from ._typing import _InstanceDict
from ._typing import _InternalEntityType
from ._typing import _ORMAdapterProto
- from ._typing import _ORMColumnExprArgument
from .attributes import InstrumentedAttribute
from .context import _MapperEntity
from .context import ORMCompileState
@@ -89,7 +89,6 @@ if typing.TYPE_CHECKING:
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
- from ..sql._typing import _PropagateAttrsType
from ..sql.operators import OperatorType
from ..sql.util import ColumnAdapter
from ..sql.visitors import _TraverseInternalsType
@@ -171,12 +170,18 @@ class _MapsColumns(_MappedAttribute[_T]):
raise NotImplementedError()
+# NOTE: MapperProperty needs to extend _MappedAttribute so that declarative
+# typing works, i.e. "Mapped[A] = relationship()". This introduces an
+# inconvenience which is that all the MapperProperty objects are treated
+# as descriptors by typing tools, which are misled by this as assignment /
+# access to a descriptor attribute wants to move through __get__.
+# Therefore, references to MapperProperty as an instance variable, such
+# as in PropComparator, may have some special typing workarounds such as the
+# use of sqlalchemy.util.typing.DescriptorReference to avoid mis-interpretation
+# by typing tools
@inspection._self_inspects
class MapperProperty(
- HasCacheKey,
- _MappedAttribute[_T],
- InspectionAttrInfo,
- util.MemoizedSlots,
+ HasCacheKey, _MappedAttribute[_T], InspectionAttrInfo, util.MemoizedSlots
):
"""Represent a particular class attribute mapped by :class:`_orm.Mapper`.
@@ -522,6 +527,7 @@ class PropComparator(SQLORMOperations[_T]):
_parententity: _InternalEntityType[Any]
_adapt_to_entity: Optional[AliasedInsp[Any]]
+ prop: DescriptorReference[MapperProperty[_T]]
def __init__(
self,
@@ -533,11 +539,20 @@ class PropComparator(SQLORMOperations[_T]):
self._parententity = adapt_to_entity or parentmapper
self._adapt_to_entity = adapt_to_entity
- @util.ro_non_memoized_property
+ @util.non_memoized_property
def property(self) -> Optional[MapperProperty[_T]]:
+ """Return the :class:`.MapperProperty` associated with this
+ :class:`.PropComparator`.
+
+
+ Return values here will commonly be instances of
+ :class:`.ColumnProperty` or :class:`.Relationship`.
+
+
+ """
return self.prop
- def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
+ def __clause_element__(self) -> roles.ColumnsClauseRole:
raise NotImplementedError("%r" % self)
def _bulk_update_tuples(
@@ -567,18 +582,6 @@ class PropComparator(SQLORMOperations[_T]):
compatible with QueryableAttribute."""
return self._parententity.mapper
- @util.memoized_property
- def _propagate_attrs(self) -> _PropagateAttrsType:
- # this suits the case in coercions where we don't actually
- # call ``__clause_element__()`` but still need to get
- # resolved._propagate_attrs. See #6558.
- return util.immutabledict(
- {
- "compile_state_plugin": "orm",
- "plugin_subject": self._parentmapper,
- }
- )
-
def _criterion_exists(
self,
criterion: Optional[_ColumnExpressionArgument[bool]] = None,
@@ -657,7 +660,7 @@ class PropComparator(SQLORMOperations[_T]):
def and_(
self, *criteria: _ColumnExpressionArgument[bool]
- ) -> ColumnElement[bool]:
+ ) -> PropComparator[bool]:
"""Add additional criteria to the ON clause that's represented by this
relationship attribute.
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 66021c9c2..514ad7023 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -71,6 +71,7 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _InternalEntityType
from .mapper import Mapper
from .util import AliasedClass
from .util import AliasedInsp
@@ -348,7 +349,7 @@ class Relationship(
doc=self.doc,
)
- class Comparator(PropComparator[_PT]):
+ class Comparator(util.MemoizedSlots, PropComparator[_PT]):
"""Produce boolean, comparison, and other operators for
:class:`.Relationship` attributes.
@@ -369,8 +370,13 @@ class Relationship(
"""
- _of_type = None
- _extra_criteria = ()
+ __slots__ = (
+ "entity",
+ "mapper",
+ "property",
+ "_of_type",
+ "_extra_criteria",
+ )
def __init__(
self,
@@ -389,6 +395,8 @@ class Relationship(
self._adapt_to_entity = adapt_to_entity
if of_type:
self._of_type = of_type
+ else:
+ self._of_type = None
self._extra_criteria = extra_criteria
def adapt_to_entity(self, adapt_to_entity):
@@ -399,40 +407,35 @@ class Relationship(
of_type=self._of_type,
)
- @util.memoized_property
- def entity(self):
- """The target entity referred to by this
- :class:`.Relationship.Comparator`.
+ entity: _InternalEntityType
+ """The target entity referred to by this
+ :class:`.Relationship.Comparator`.
- This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
- object.
+ This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
+ object.
- This is the "target" or "remote" side of the
- :func:`_orm.relationship`.
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
- """
- # this is a relatively recent change made for
- # 1.4.27 as part of #7244.
- # TODO: shouldn't _of_type be inspected up front when received?
- if self._of_type is not None:
- return inspect(self._of_type)
- else:
- return self.property.entity
+ """
- @util.memoized_property
- def mapper(self):
- """The target :class:`_orm.Mapper` referred to by this
- :class:`.Relationship.Comparator`.
+ mapper: Mapper[Any]
+ """The target :class:`_orm.Mapper` referred to by this
+ :class:`.Relationship.Comparator`.
- This is the "target" or "remote" side of the
- :func:`_orm.relationship`.
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
- """
- return self.property.mapper
+ """
- @util.memoized_property
- def _parententity(self):
- return self.property.parent
+ def _memoized_attr_entity(self) -> _InternalEntityType:
+ if self._of_type:
+ return inspect(self._of_type)
+ else:
+ return self.prop.entity
+
+ def _memoized_attr_mapper(self) -> Mapper[Any]:
+ return self.entity.mapper
def _source_selectable(self):
if self._adapt_to_entity:
@@ -481,7 +484,9 @@ class Relationship(
extra_criteria=self._extra_criteria,
)
- def and_(self, *other):
+ def and_(
+ self, *criteria: _ColumnExpressionArgument[bool]
+ ) -> interfaces.PropComparator[bool]:
"""Add AND criteria.
See :meth:`.PropComparator.and_` for an example.
@@ -489,12 +494,17 @@ class Relationship(
.. versionadded:: 1.4
"""
+ exprs = tuple(
+ coercions.expect(roles.WhereHavingRole, clause)
+ for clause in util.coerce_generator_arg(criteria)
+ )
+
return Relationship.Comparator(
self.property,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=self._of_type,
- extra_criteria=self._extra_criteria + other,
+ extra_criteria=self._extra_criteria + exprs,
)
def in_(self, other):
@@ -924,8 +934,7 @@ class Relationship(
else:
return _orm_annotate(self.__negated_contains_or_equals(other))
- @util.memoized_property
- def property(self):
+ def _memoized_attr_property(self):
self.prop.parent._check_configure()
return self.prop
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index ab32a3981..49ee701b4 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -23,6 +23,7 @@ from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
+from typing import Union
import weakref
from . import base
@@ -43,6 +44,7 @@ from .path_registry import PathRegistry
from .. import exc as sa_exc
from .. import inspection
from .. import util
+from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
from .attributes import History
from .base import LoaderCallableStatus
from .base import PassiveFlag
+ from .collections import _AdaptedCollectionProtocol
from .identity import IdentityMap
from .instrumentation import ClassManager
from .interfaces import ORMOption
@@ -421,7 +424,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
return self.key
@util.memoized_property
- def parents(self) -> Dict[int, InstanceState[Any]]:
+ def parents(self) -> Dict[int, Union[Literal[False], InstanceState[Any]]]:
return {}
@util.memoized_property
@@ -429,7 +432,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
return {}
@util.memoized_property
- def _empty_collections(self) -> Dict[Any, Any]:
+ def _empty_collections(self) -> Dict[str, _AdaptedCollectionProtocol]:
return {}
@util.memoized_property
@@ -844,7 +847,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
def _modified_event(
self,
dict_: _InstanceDict,
- attr: AttributeImpl,
+ attr: Optional[AttributeImpl],
previous: Any,
collection: bool = False,
is_userland: bool = False,
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 7e8a6b4c6..b095e3f7a 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -24,6 +24,7 @@ from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import weakref
@@ -531,6 +532,8 @@ class AliasedClass(
"""
+ __name__: str
+
def __init__(
self,
mapped_class_or_ac: _EntityType[_O],
@@ -1529,7 +1532,7 @@ class _ORMJoin(expression.Join):
full: bool = False,
_left_memo: Optional[Any] = None,
_right_memo: Optional[Any] = None,
- _extra_criteria: Sequence[ColumnElement[bool]] = (),
+ _extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
left_info = cast(
"Union[FromClause, _InternalEntityType[Any]]",
@@ -1547,6 +1550,8 @@ class _ORMJoin(expression.Join):
self._right_memo = _right_memo
if isinstance(onclause, attributes.QueryableAttribute):
+ if TYPE_CHECKING:
+ assert isinstance(onclause.comparator, Relationship.Comparator)
on_selectable = onclause.comparator._source_selectable()
prop = onclause.property
_extra_criteria += onclause._extra_criteria
@@ -1728,12 +1733,15 @@ def with_parent(
elif isinstance(prop, attributes.QueryableAttribute):
if prop._of_type:
from_entity = prop._of_type
- if not prop_is_relationship(prop.property):
+ mapper_property = prop.property
+ if mapper_property is None or not prop_is_relationship(
+ mapper_property
+ ):
raise sa_exc.ArgumentError(
f"Expected relationship property for with_parent(), "
- f"got {prop.property}"
+ f"got {mapper_property}"
)
- prop_t = prop.property
+ prop_t = mapper_property
else:
prop_t = prop
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index fb959654f..248b48a25 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -133,6 +133,8 @@ class Immutable:
"""
+ __slots__ = ()
+
_is_immutable = True
def unique_params(self, *optionaldict, **kwargs):
@@ -145,7 +147,7 @@ class Immutable:
return self
def _copy_internals(
- self, omit_attrs: Iterable[str] = (), **kw: Any
+ self, *, omit_attrs: Iterable[str] = (), **kw: Any
) -> None:
pass
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 15fbc2afb..c16fbdae1 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -36,7 +36,6 @@ if typing.TYPE_CHECKING:
from .elements import BindParameter
from .elements import ClauseElement
from .visitors import _TraverseInternalsType
- from ..engine.base import _CompiledCacheType
from ..engine.interfaces import _CoreSingleExecuteParams
@@ -393,6 +392,13 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
return HasCacheKey._generate_cache_key(self)
+class SlotsMemoizedHasCacheKey(HasCacheKey, util.MemoizedSlots):
+ __slots__ = ()
+
+ def _memoized_method__generate_cache_key(self) -> Optional[CacheKey]:
+ return HasCacheKey._generate_cache_key(self)
+
+
class CacheKey(NamedTuple):
"""The key used to identify a SQL statement construct in the
SQL compilation cache.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 2e0112f08..2655adbdc 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1265,7 +1265,16 @@ class ColumnAdapter(ClauseAdapter):
if self.adapt_required and c is col:
return None
- c._allow_label_resolve = self.allow_label_resolve
+ # allow_label_resolve is consumed by one case for joined eager loading
+ # as part of its logic to prevent its own columns from being affected
+ # by .order_by(). Before full typing were applied to the ORM, this
+ # logic would set this attribute on the incoming object (which is
+ # typically a column, but we have a test for it being a non-column
+ # object) if no column were found. While this seemed to
+ # have no negative effects, this adjustment should only occur on the
+ # new column which is assumed to be local to an adapted selectable.
+ if c is not col:
+ c._allow_label_resolve = self.allow_label_resolve
return c
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index e9b0c93f2..7150dedcf 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -410,11 +410,11 @@ class UniqueAppender(Generic[_T]):
return iter(self.data)
-def coerce_generator_arg(arg):
+def coerce_generator_arg(arg: Any) -> List[Any]:
if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
return list(arg[0])
else:
- return arg
+ return cast("List[Any]", arg)
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 10110dbbe..24c66bfa4 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1272,17 +1272,20 @@ class MemoizedSlots:
def _fallback_getattr(self, key):
raise AttributeError(key)
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
if key.startswith("_memoized_attr_") or key.startswith(
"_memoized_method_"
):
raise AttributeError(key)
- elif hasattr(self, "_memoized_attr_%s" % key):
- value = getattr(self, "_memoized_attr_%s" % key)()
+ # to avoid recursion errors when interacting with other __getattr__
+ # schemes that refer to this one, when testing for memoized method
+ # look at __class__ only rather than going into __getattr__ again.
+ elif hasattr(self.__class__, f"_memoized_attr_{key}"):
+ value = getattr(self, f"_memoized_attr_{key}")()
setattr(self, key, value)
return value
- elif hasattr(self, "_memoized_method_%s" % key):
- fn = getattr(self, "_memoized_method_%s" % key)
+ elif hasattr(self.__class__, f"_memoized_method_{key}"):
+ fn = getattr(self, f"_memoized_method_{key}")
def oneshot(*args, **kw):
result = fn(*args, **kw)
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 4929ba1a6..a95f5ab93 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -9,6 +9,7 @@ from typing import Callable
from typing import cast
from typing import Dict
from typing import ForwardRef
+from typing import Generic
from typing import Iterable
from typing import Optional
from typing import Tuple
@@ -213,3 +214,41 @@ def _get_type_name(type_: Type[Any]) -> str:
typ_name = getattr(type_, "_name", None)
return typ_name # type: ignore
+
+
+class DescriptorProto(Protocol):
+ def __get__(self, instance: object, owner: Any) -> Any:
+ ...
+
+ def __set__(self, instance: Any, value: Any) -> None:
+ ...
+
+ def __delete__(self, instance: Any) -> None:
+ ...
+
+
+_DESC = TypeVar("_DESC", bound=DescriptorProto)
+
+
+class DescriptorReference(Generic[_DESC]):
+ """a descriptor that refers to a descriptor.
+
+ used for cases where we need to have an instance variable referring to an
+ object that is itself a descriptor, which typically confuses typing tools
+ as they don't know when they should use ``__get__`` or not when referring
+ to the descriptor assignment as an instance variable. See
+ sqlalchemy.orm.interfaces.PropComparator.prop
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _DESC:
+ ...
+
+ def __set__(self, instance: Any, value: _DESC) -> None:
+ ...
+
+ def __delete__(self, instance: Any) -> None:
+ ...
+
+
+# $def ro_descriptor_reference(fn: Callable[])