diff options
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index aceed99a5..94e635740 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -19,6 +19,7 @@ from typing import Callable from typing import Deque from typing import Dict from typing import Iterable +from typing import Optional from typing import Set from typing import Tuple from typing import Type @@ -39,7 +40,7 @@ COMPARE_FAILED = False COMPARE_SUCCEEDED = True -def compare(obj1, obj2, **kw): +def compare(obj1: Any, obj2: Any, **kw: Any) -> bool: strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() @@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -def _preconfigure_traversals(target_hierarchy): +def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None: for cls in util.walk_subclasses(target_hierarchy): if hasattr(cls, "_generate_cache_attrs") and hasattr( cls, "_traverse_internals" @@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): def __init__(self): self.stack: Deque[ - Tuple[ExternallyTraversible, ExternallyTraversible] + Tuple[ + Optional[ExternallyTraversible], + Optional[ExternallyTraversible], + ] ] = deque() self.cache = set() def _memoized_attr_anon_map(self): return (anon_map(), anon_map()) - def compare(self, obj1, obj2, **kw): + def compare( + self, + obj1: ExternallyTraversible, + obj2: ExternallyTraversible, + **kw: Any, + ) -> bool: stack = self.stack cache = self.cache @@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): elif left_attrname in attributes_compared: continue + assert left_visit_sym is not None + assert left_attrname is not None + assert right_attrname is not None + dispatch = self.dispatch(left_visit_sym) assert dispatch, ( f"{self.__class__} has no dispatch for " @@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( self.anon_map[1], [] ): @@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if ( l._gen_cache_key(self.anon_map[0], []) if l._is_has_cache_key |