From 18a73fb1d1c267842ead5dacd05a49f4344d8b22 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 6 May 2022 16:09:52 -0400 Subject: revenge of pep 484 trying to get remaining must-haves for ORM Change-Id: I66a3ecbbb8e5ba37c818c8a92737b576ecf012f7 --- lib/sqlalchemy/sql/traversals.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy/sql/traversals.py') diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index aceed99a5..94e635740 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -19,6 +19,7 @@ from typing import Callable from typing import Deque from typing import Dict from typing import Iterable +from typing import Optional from typing import Set from typing import Tuple from typing import Type @@ -39,7 +40,7 @@ COMPARE_FAILED = False COMPARE_SUCCEEDED = True -def compare(obj1, obj2, **kw): +def compare(obj1: Any, obj2: Any, **kw: Any) -> bool: strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() @@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -def _preconfigure_traversals(target_hierarchy): +def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None: for cls in util.walk_subclasses(target_hierarchy): if hasattr(cls, "_generate_cache_attrs") and hasattr( cls, "_traverse_internals" @@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): def __init__(self): self.stack: Deque[ - Tuple[ExternallyTraversible, ExternallyTraversible] + Tuple[ + Optional[ExternallyTraversible], + Optional[ExternallyTraversible], + ] ] = deque() self.cache = set() def _memoized_attr_anon_map(self): return (anon_map(), anon_map()) - def compare(self, obj1, obj2, **kw): + def compare( + self, + obj1: ExternallyTraversible, + obj2: ExternallyTraversible, + **kw: Any, + ) -> bool: stack = self.stack cache = self.cache @@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): elif left_attrname in attributes_compared: continue + assert left_visit_sym is not None + assert left_attrname is not None + assert right_attrname is not None + dispatch = self.dispatch(left_visit_sym) assert dispatch, ( f"{self.__class__} has no dispatch for " @@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( self.anon_map[1], [] ): @@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in zip_longest(left, right, fillvalue=None): + if l is None: + if r is not None: + return COMPARE_FAILED + else: + continue + elif r is None: + return COMPARE_FAILED + if ( l._gen_cache_key(self.anon_map[0], []) if l._is_has_cache_key -- cgit v1.2.1