summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/traversals.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r--lib/sqlalchemy/sql/traversals.py64
1 files changed, 37 insertions, 27 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 4fa23d370..cf9487f93 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -15,7 +15,10 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import Deque
from typing import Dict
+from typing import Set
+from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -23,9 +26,9 @@ from . import operators
from .cache_key import HasCacheKey
from .visitors import _TraverseInternalsType
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
+from .visitors import HasTraversalDispatch
from .visitors import HasTraverseInternals
-from .visitors import InternalTraversal
from .. import util
from ..util import langhelpers
@@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True
def compare(obj1, obj2, **kw):
+ strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
@@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw):
def _preconfigure_traversals(target_hierarchy):
for cls in util.walk_subclasses(target_hierarchy):
- if hasattr(cls, "_traverse_internals"):
- cls._generate_cache_attrs()
+ if hasattr(cls, "_generate_cache_attrs") and hasattr(
+ cls, "_traverse_internals"
+ ):
+ cls._generate_cache_attrs() # type: ignore
_copy_internals.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_copy_internals_traversal",
)
_get_children.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_get_children_traversal",
)
@@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals):
meth_text = f"def {method_name}(self, d):\n{code}\n"
return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- def _shallow_from_dict(self, d: Dict) -> None:
+ def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
cls = self.__class__
+ shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
try:
shallow_from_dict = cls.__dict__[
"_generated_shallow_from_dict_traversal"
]
except KeyError:
- shallow_from_dict = (
- cls._generated_shallow_from_dict_traversal # type: ignore
- ) = self._generate_shallow_from_dict(
+ shallow_from_dict = self._generate_shallow_from_dict(
cls._traverse_internals,
"_generated_shallow_from_dict_traversal",
)
+ cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501
+
shallow_from_dict(self, d)
def _shallow_to_dict(self) -> Dict[str, Any]:
cls = self.__class__
+ shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
+
try:
shallow_to_dict = cls.__dict__[
"_generated_shallow_to_dict_traversal"
]
except KeyError:
- shallow_to_dict = (
- cls._generated_shallow_to_dict_traversal # type: ignore
- ) = self._generate_shallow_to_dict(
+ shallow_to_dict = self._generate_shallow_to_dict(
cls._traverse_internals, "_generated_shallow_to_dict_traversal"
)
+ cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501
return shallow_to_dict(self)
- def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+ def _shallow_copy_to(
+ self: SelfHasShallowCopy, other: SelfHasShallowCopy
+ ) -> None:
cls = self.__class__
+ shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]
try:
shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
except KeyError:
- shallow_copy = (
- cls._generated_shallow_copy_traversal # type: ignore
- ) = self._generate_shallow_copy(
+ shallow_copy = self._generate_shallow_copy(
cls._traverse_internals, "_generated_shallow_copy_traversal"
)
+ cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501
shallow_copy(self, other)
- def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+ def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy:
"""Create a shallow copy"""
c = self.__class__.__new__(self.__class__)
self._shallow_copy_to(c)
@@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals):
setattr(self, attrname, result)
-class _CopyInternalsTraversal(InternalTraversal):
+class _CopyInternalsTraversal(HasTraversalDispatch):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -381,7 +391,7 @@ def _flatten_clauseelement(element):
return element
-class _GetChildrenTraversal(InternalTraversal):
+class _GetChildrenTraversal(HasTraversalDispatch):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
return name
-class TraversalComparatorStrategy(
- ExtendedInternalTraversal, util.MemoizedSlots
-):
+class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
- self.stack = deque()
+ self.stack: Deque[
+ Tuple[ExternallyTraversible, ExternallyTraversible]
+ ] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
@@ -653,7 +663,7 @@ class TraversalComparatorStrategy(
if seq1 is None:
return seq2 is None
- completed = set()
+ completed: Set[object] = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):