diff options
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 64 |
1 files changed, 37 insertions, 27 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 4fa23d370..cf9487f93 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -15,7 +15,10 @@ import operator import typing from typing import Any from typing import Callable +from typing import Deque from typing import Dict +from typing import Set +from typing import Tuple from typing import Type from typing import TypeVar @@ -23,9 +26,9 @@ from . import operators from .cache_key import HasCacheKey from .visitors import _TraverseInternalsType from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible +from .visitors import HasTraversalDispatch from .visitors import HasTraverseInternals -from .visitors import InternalTraversal from .. import util from ..util import langhelpers @@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True def compare(obj1, obj2, **kw): + strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() else: @@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw): def _preconfigure_traversals(target_hierarchy): for cls in util.walk_subclasses(target_hierarchy): - if hasattr(cls, "_traverse_internals"): - cls._generate_cache_attrs() + if hasattr(cls, "_generate_cache_attrs") and hasattr( + cls, "_traverse_internals" + ): + cls._generate_cache_attrs() # type: ignore _copy_internals.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_copy_internals_traversal", ) _get_children.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_get_children_traversal", ) @@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals): meth_text = f"def {method_name}(self, d):\n{code}\n" return langhelpers._exec_code_in_env(meth_text, {}, method_name) - def _shallow_from_dict(self, d: Dict) -> None: + def _shallow_from_dict(self, d: Dict[str, Any]) -> None: cls = self.__class__ + shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None] try: shallow_from_dict = cls.__dict__[ "_generated_shallow_from_dict_traversal" ] except KeyError: - shallow_from_dict = ( - cls._generated_shallow_from_dict_traversal # type: ignore - ) = self._generate_shallow_from_dict( + shallow_from_dict = self._generate_shallow_from_dict( cls._traverse_internals, "_generated_shallow_from_dict_traversal", ) + cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501 + shallow_from_dict(self, d) def _shallow_to_dict(self) -> Dict[str, Any]: cls = self.__class__ + shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]] + try: shallow_to_dict = cls.__dict__[ "_generated_shallow_to_dict_traversal" ] except KeyError: - shallow_to_dict = ( - cls._generated_shallow_to_dict_traversal # type: ignore - ) = self._generate_shallow_to_dict( + shallow_to_dict = self._generate_shallow_to_dict( cls._traverse_internals, "_generated_shallow_to_dict_traversal" ) + cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501 return shallow_to_dict(self) - def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy): + def _shallow_copy_to( + self: SelfHasShallowCopy, other: SelfHasShallowCopy + ) -> None: cls = self.__class__ + shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None] try: shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"] except KeyError: - shallow_copy = ( - cls._generated_shallow_copy_traversal # type: ignore - ) = self._generate_shallow_copy( + shallow_copy = self._generate_shallow_copy( cls._traverse_internals, "_generated_shallow_copy_traversal" ) + cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501 shallow_copy(self, other) - def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy: + def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy: """Create a shallow copy""" c = self.__class__.__new__(self.__class__) self._shallow_copy_to(c) @@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals): setattr(self, attrname, result) -class _CopyInternalsTraversal(InternalTraversal): +class _CopyInternalsTraversal(HasTraversalDispatch): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -381,7 +391,7 @@ def _flatten_clauseelement(element): return element -class _GetChildrenTraversal(InternalTraversal): +class _GetChildrenTraversal(HasTraversalDispatch): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw): return name -class TraversalComparatorStrategy( - ExtendedInternalTraversal, util.MemoizedSlots -): +class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): __slots__ = "stack", "cache", "anon_map" def __init__(self): - self.stack = deque() + self.stack: Deque[ + Tuple[ExternallyTraversible, ExternallyTraversible] + ] = deque() self.cache = set() def _memoized_attr_anon_map(self): @@ -653,7 +663,7 @@ class TraversalComparatorStrategy( if seq1 is None: return seq2 is None - completed = set() + completed: Set[object] = set() for clause in seq1: for other_clause in set(seq2).difference(completed): if self.compare_inner(clause, other_clause, **kw): |