summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/instrumentation.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-28 16:19:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-03 15:58:45 -0400
commit1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch)
tree9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/instrumentation.py
parent6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff)
downloadsqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz
pep484: attributes and related
also implements __slots__ for QueryableAttribute, InstrumentedAttribute, Relationship.Comparator. Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/instrumentation.py')
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py121
1 files changed, 77 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 356958562..85b85215e 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Defines SQLAlchemy's system of class instrumentation.
@@ -35,14 +35,19 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
from typing import Dict
from typing import Generic
+from typing import Iterable
+from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
import weakref
from . import base
@@ -51,15 +56,21 @@ from . import exc
from . import interfaces
from . import state
from ._typing import _O
+from .attributes import _is_collection_attribute_impl
from .. import util
from ..event import EventTarget
from ..util import HasMemoized
+from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
from ._typing import _RegistryType
+ from .attributes import AttributeImpl
from .attributes import InstrumentedAttribute
+ from .collections import _AdaptedCollectionProtocol
+ from .collections import _CollectionFactoryType
from .decl_base import _MapperConfig
+ from .events import InstanceEvents
from .mapper import Mapper
from .state import InstanceState
from ..event import dispatcher
@@ -74,7 +85,7 @@ class _ExpiredAttributeLoaderProto(Protocol):
state: state.InstanceState[Any],
toload: Set[str],
passive: base.PassiveFlag,
- ):
+ ) -> None:
...
@@ -91,7 +102,7 @@ class ClassManager(
):
"""Tracks state information at the class level."""
- dispatch: dispatcher[ClassManager]
+ dispatch: dispatcher[ClassManager[_O]]
MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
STATE_ATTR = base.DEFAULT_STATE_ATTR
@@ -108,8 +119,9 @@ class ClassManager(
declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
registry: Optional[_RegistryType] = None
- @property
- @util.deprecated(
+ _bases: List[ClassManager[Any]]
+
+ @util.deprecated_property(
"1.4",
message="The ClassManager.deferred_scalar_loader attribute is now "
"named expired_attribute_loader",
@@ -117,7 +129,7 @@ class ClassManager(
def deferred_scalar_loader(self):
return self.expired_attribute_loader
- @deferred_scalar_loader.setter
+ @deferred_scalar_loader.setter # type: ignore[no-redef]
@util.deprecated(
"1.4",
message="The ClassManager.deferred_scalar_loader attribute is now "
@@ -138,18 +150,23 @@ class ClassManager(
self._bases = [
mgr
- for mgr in [
- opt_manager_of_class(base)
- for base in self.class_.__bases__
- if isinstance(base, type)
- ]
+ for mgr in cast(
+ "List[Optional[ClassManager[Any]]]",
+ [
+ opt_manager_of_class(base)
+ for base in self.class_.__bases__
+ if isinstance(base, type)
+ ],
+ )
if mgr is not None
]
for base_ in self._bases:
self.update(base_)
- self.dispatch._events._new_classmanager_instance(class_, self)
+ cast(
+ "InstanceEvents", self.dispatch._events
+ )._new_classmanager_instance(class_, self)
for basecls in class_.__mro__:
mgr = opt_manager_of_class(basecls)
@@ -263,7 +280,7 @@ class ClassManager(
"""
- found = {}
+ found: Dict[str, Any] = {}
# constraints:
# 1. yield keys in cls.__dict__ order
@@ -303,7 +320,7 @@ class ClassManager(
return key in self and self[key].impl is not None
- def _subclass_manager(self, cls):
+ def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]:
"""Create a new ClassManager for a subclass of this ClassManager's
class.
@@ -321,7 +338,7 @@ class ClassManager(
self.install_member("__init__", self.new_init)
@util.memoized_property
- def _state_constructor(self):
+ def _state_constructor(self) -> Type[state.InstanceState[_O]]:
self.dispatch.first_init(self, self.class_)
return state.InstanceState
@@ -393,13 +410,15 @@ class ClassManager(
if manager:
manager.uninstrument_attribute(key, True)
- def unregister(self):
+ def unregister(self) -> None:
"""remove all instrumentation established by this ClassManager."""
for key in list(self.originals):
self.uninstall_member(key)
- self.mapper = self.dispatch = self.new_init = None
+ self.mapper = None # type: ignore
+ self.dispatch = None # type: ignore
+ self.new_init = None
self.info.clear()
for key in list(self):
@@ -409,7 +428,9 @@ class ClassManager(
if self.MANAGER_ATTR in self.class_.__dict__:
delattr(self.class_, self.MANAGER_ATTR)
- def install_descriptor(self, key, inst):
+ def install_descriptor(
+ self, key: str, inst: InstrumentedAttribute[Any]
+ ) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
"%r: requested attribute name conflicts with "
@@ -417,10 +438,10 @@ class ClassManager(
)
setattr(self.class_, key, inst)
- def uninstall_descriptor(self, key):
+ def uninstall_descriptor(self, key: str) -> None:
delattr(self.class_, key)
- def install_member(self, key, implementation):
+ def install_member(self, key: str, implementation: Any) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
"%r: requested attribute name conflicts with "
@@ -429,34 +450,41 @@ class ClassManager(
self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
setattr(self.class_, key, implementation)
- def uninstall_member(self, key):
+ def uninstall_member(self, key: str) -> None:
original = self.originals.pop(key, None)
if original is not DEL_ATTR:
setattr(self.class_, key, original)
else:
delattr(self.class_, key)
- def instrument_collection_class(self, key, collection_class):
+ def instrument_collection_class(
+ self, key: str, collection_class: Type[Collection[Any]]
+ ) -> _CollectionFactoryType:
return collections.prepare_instrumentation(collection_class)
- def initialize_collection(self, key, state, factory):
+ def initialize_collection(
+ self,
+ key: str,
+ state: InstanceState[_O],
+ factory: _CollectionFactoryType,
+ ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]:
user_data = factory()
- adapter = collections.CollectionAdapter(
- self.get_impl(key), state, user_data
- )
+ impl = self.get_impl(key)
+ assert _is_collection_attribute_impl(impl)
+ adapter = collections.CollectionAdapter(impl, state, user_data)
return adapter, user_data
- def is_instrumented(self, key, search=False):
+ def is_instrumented(self, key: str, search: bool = False) -> bool:
if search:
return key in self
else:
return key in self.local_attrs
- def get_impl(self, key):
+ def get_impl(self, key: str) -> AttributeImpl:
return self[key].impl
@property
- def attributes(self):
+ def attributes(self) -> Iterable[Any]:
return iter(self.values())
# InstanceState management
@@ -466,22 +494,26 @@ class ClassManager(
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
- return instance
+ return instance # type: ignore[no-any-return]
- def setup_instance(self, instance, state=None):
+ def setup_instance(
+ self, instance: _O, state: Optional[InstanceState[_O]] = None
+ ) -> None:
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
- def teardown_instance(self, instance):
+ def teardown_instance(self, instance: _O) -> None:
delattr(instance, self.STATE_ATTR)
def _serialize(
- self, state: state.InstanceState, state_dict: Dict[str, Any]
+ self, state: InstanceState[_O], state_dict: Dict[str, Any]
) -> _SerializeManager:
return _SerializeManager(state, state_dict)
- def _new_state_if_none(self, instance):
+ def _new_state_if_none(
+ self, instance: _O
+ ) -> Union[Literal[False], InstanceState[_O]]:
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
@@ -503,20 +535,20 @@ class ClassManager(
self._state_setter(instance, state)
return state
- def has_state(self, instance):
+ def has_state(self, instance: _O) -> bool:
return hasattr(instance, self.STATE_ATTR)
- def has_parent(self, state, key, optimistic=False):
+ def has_parent(
+ self, state: InstanceState[_O], key: str, optimistic: bool = False
+ ) -> bool:
"""TODO"""
return self.get_impl(key).hasparent(state, optimistic=optimistic)
- def __bool__(self):
+ def __bool__(self) -> bool:
"""All ClassManagers are non-zero regardless of attribute state."""
return True
- __nonzero__ = __bool__
-
- def __repr__(self):
+ def __repr__(self) -> str:
return "<%s of %r at %x>" % (
self.__class__.__name__,
self.class_,
@@ -558,9 +590,11 @@ class _SerializeManager:
manager.dispatch.unpickle(state, state_dict)
-class InstrumentationFactory:
+class InstrumentationFactory(EventTarget):
"""Factory for new ClassManager instances."""
+ dispatch: dispatcher[InstrumentationFactory]
+
def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
assert class_ is not None
assert opt_manager_of_class(class_) is None
@@ -589,11 +623,10 @@ class InstrumentationFactory:
def _check_conflicts(
self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
- ):
+ ) -> None:
"""Overridden by a subclass to test for conflicting factories."""
- return
- def unregister(self, class_):
+ def unregister(self, class_: Type[_O]) -> None:
manager = manager_of_class(class_)
manager.unregister()
self.dispatch.class_uninstrument(class_)