diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-28 16:19:43 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-03 15:58:45 -0400 |
commit | 1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch) | |
tree | 9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/instrumentation.py | |
parent | 6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff) | |
download | sqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz |
pep484: attributes and related
also implements __slots__ for QueryableAttribute,
InstrumentedAttribute, Relationship.Comparator.
Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/instrumentation.py')
-rw-r--r-- | lib/sqlalchemy/orm/instrumentation.py | 121 |
1 files changed, 77 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 356958562..85b85215e 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Defines SQLAlchemy's system of class instrumentation. @@ -35,14 +35,19 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import cast +from typing import Collection from typing import Dict from typing import Generic +from typing import Iterable +from typing import List from typing import Optional from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union import weakref from . import base @@ -51,15 +56,21 @@ from . import exc from . import interfaces from . import state from ._typing import _O +from .attributes import _is_collection_attribute_impl from .. import util from ..event import EventTarget from ..util import HasMemoized +from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: from ._typing import _RegistryType + from .attributes import AttributeImpl from .attributes import InstrumentedAttribute + from .collections import _AdaptedCollectionProtocol + from .collections import _CollectionFactoryType from .decl_base import _MapperConfig + from .events import InstanceEvents from .mapper import Mapper from .state import InstanceState from ..event import dispatcher @@ -74,7 +85,7 @@ class _ExpiredAttributeLoaderProto(Protocol): state: state.InstanceState[Any], toload: Set[str], passive: base.PassiveFlag, - ): + ) -> None: ... @@ -91,7 +102,7 @@ class ClassManager( ): """Tracks state information at the class level.""" - dispatch: dispatcher[ClassManager] + dispatch: dispatcher[ClassManager[_O]] MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR STATE_ATTR = base.DEFAULT_STATE_ATTR @@ -108,8 +119,9 @@ class ClassManager( declarative_scan: Optional[weakref.ref[_MapperConfig]] = None registry: Optional[_RegistryType] = None - @property - @util.deprecated( + _bases: List[ClassManager[Any]] + + @util.deprecated_property( "1.4", message="The ClassManager.deferred_scalar_loader attribute is now " "named expired_attribute_loader", @@ -117,7 +129,7 @@ class ClassManager( def deferred_scalar_loader(self): return self.expired_attribute_loader - @deferred_scalar_loader.setter + @deferred_scalar_loader.setter # type: ignore[no-redef] @util.deprecated( "1.4", message="The ClassManager.deferred_scalar_loader attribute is now " @@ -138,18 +150,23 @@ class ClassManager( self._bases = [ mgr - for mgr in [ - opt_manager_of_class(base) - for base in self.class_.__bases__ - if isinstance(base, type) - ] + for mgr in cast( + "List[Optional[ClassManager[Any]]]", + [ + opt_manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ], + ) if mgr is not None ] for base_ in self._bases: self.update(base_) - self.dispatch._events._new_classmanager_instance(class_, self) + cast( + "InstanceEvents", self.dispatch._events + )._new_classmanager_instance(class_, self) for basecls in class_.__mro__: mgr = opt_manager_of_class(basecls) @@ -263,7 +280,7 @@ class ClassManager( """ - found = {} + found: Dict[str, Any] = {} # constraints: # 1. yield keys in cls.__dict__ order @@ -303,7 +320,7 @@ class ClassManager( return key in self and self[key].impl is not None - def _subclass_manager(self, cls): + def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]: """Create a new ClassManager for a subclass of this ClassManager's class. @@ -321,7 +338,7 @@ class ClassManager( self.install_member("__init__", self.new_init) @util.memoized_property - def _state_constructor(self): + def _state_constructor(self) -> Type[state.InstanceState[_O]]: self.dispatch.first_init(self, self.class_) return state.InstanceState @@ -393,13 +410,15 @@ class ClassManager( if manager: manager.uninstrument_attribute(key, True) - def unregister(self): + def unregister(self) -> None: """remove all instrumentation established by this ClassManager.""" for key in list(self.originals): self.uninstall_member(key) - self.mapper = self.dispatch = self.new_init = None + self.mapper = None # type: ignore + self.dispatch = None # type: ignore + self.new_init = None self.info.clear() for key in list(self): @@ -409,7 +428,9 @@ class ClassManager( if self.MANAGER_ATTR in self.class_.__dict__: delattr(self.class_, self.MANAGER_ATTR) - def install_descriptor(self, key, inst): + def install_descriptor( + self, key: str, inst: InstrumentedAttribute[Any] + ) -> None: if key in (self.STATE_ATTR, self.MANAGER_ATTR): raise KeyError( "%r: requested attribute name conflicts with " @@ -417,10 +438,10 @@ class ClassManager( ) setattr(self.class_, key, inst) - def uninstall_descriptor(self, key): + def uninstall_descriptor(self, key: str) -> None: delattr(self.class_, key) - def install_member(self, key, implementation): + def install_member(self, key: str, implementation: Any) -> None: if key in (self.STATE_ATTR, self.MANAGER_ATTR): raise KeyError( "%r: requested attribute name conflicts with " @@ -429,34 +450,41 @@ class ClassManager( self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR)) setattr(self.class_, key, implementation) - def uninstall_member(self, key): + def uninstall_member(self, key: str) -> None: original = self.originals.pop(key, None) if original is not DEL_ATTR: setattr(self.class_, key, original) else: delattr(self.class_, key) - def instrument_collection_class(self, key, collection_class): + def instrument_collection_class( + self, key: str, collection_class: Type[Collection[Any]] + ) -> _CollectionFactoryType: return collections.prepare_instrumentation(collection_class) - def initialize_collection(self, key, state, factory): + def initialize_collection( + self, + key: str, + state: InstanceState[_O], + factory: _CollectionFactoryType, + ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]: user_data = factory() - adapter = collections.CollectionAdapter( - self.get_impl(key), state, user_data - ) + impl = self.get_impl(key) + assert _is_collection_attribute_impl(impl) + adapter = collections.CollectionAdapter(impl, state, user_data) return adapter, user_data - def is_instrumented(self, key, search=False): + def is_instrumented(self, key: str, search: bool = False) -> bool: if search: return key in self else: return key in self.local_attrs - def get_impl(self, key): + def get_impl(self, key: str) -> AttributeImpl: return self[key].impl @property - def attributes(self): + def attributes(self) -> Iterable[Any]: return iter(self.values()) # InstanceState management @@ -466,22 +494,26 @@ class ClassManager( if state is None: state = self._state_constructor(instance, self) self._state_setter(instance, state) - return instance + return instance # type: ignore[no-any-return] - def setup_instance(self, instance, state=None): + def setup_instance( + self, instance: _O, state: Optional[InstanceState[_O]] = None + ) -> None: if state is None: state = self._state_constructor(instance, self) self._state_setter(instance, state) - def teardown_instance(self, instance): + def teardown_instance(self, instance: _O) -> None: delattr(instance, self.STATE_ATTR) def _serialize( - self, state: state.InstanceState, state_dict: Dict[str, Any] + self, state: InstanceState[_O], state_dict: Dict[str, Any] ) -> _SerializeManager: return _SerializeManager(state, state_dict) - def _new_state_if_none(self, instance): + def _new_state_if_none( + self, instance: _O + ) -> Union[Literal[False], InstanceState[_O]]: """Install a default InstanceState if none is present. A private convenience method used by the __init__ decorator. @@ -503,20 +535,20 @@ class ClassManager( self._state_setter(instance, state) return state - def has_state(self, instance): + def has_state(self, instance: _O) -> bool: return hasattr(instance, self.STATE_ATTR) - def has_parent(self, state, key, optimistic=False): + def has_parent( + self, state: InstanceState[_O], key: str, optimistic: bool = False + ) -> bool: """TODO""" return self.get_impl(key).hasparent(state, optimistic=optimistic) - def __bool__(self): + def __bool__(self) -> bool: """All ClassManagers are non-zero regardless of attribute state.""" return True - __nonzero__ = __bool__ - - def __repr__(self): + def __repr__(self) -> str: return "<%s of %r at %x>" % ( self.__class__.__name__, self.class_, @@ -558,9 +590,11 @@ class _SerializeManager: manager.dispatch.unpickle(state, state_dict) -class InstrumentationFactory: +class InstrumentationFactory(EventTarget): """Factory for new ClassManager instances.""" + dispatch: dispatcher[InstrumentationFactory] + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: assert class_ is not None assert opt_manager_of_class(class_) is None @@ -589,11 +623,10 @@ class InstrumentationFactory: def _check_conflicts( self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] - ): + ) -> None: """Overridden by a subclass to test for conflicting factories.""" - return - def unregister(self, class_): + def unregister(self, class_: Type[_O]) -> None: manager = manager_of_class(class_) manager.unregister() self.dispatch.class_uninstrument(class_) |