diff options
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 216 |
1 files changed, 173 insertions, 43 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 8c63fcba1..a308feb7c 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -class HasCacheKey(HasMemoized): +class HasCacheKey(object): _cache_key_traversal = NO_CACHE - __slots__ = () def _gen_cache_key(self, anon_map, bindparams): @@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized): return result - @HasMemoized.memoized_instancemethod def _generate_cache_key(self): """return a cache key. @@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized): else: return CacheKey(key, bindparams) + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __hash__(self): @@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __eq__(self, other): return self.key == other.key + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + def __str__(self): stack = [self.key] @@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal): visit_type = STATIC_CACHE_KEY def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return self.visit_has_cache_key( - attrname, inspect(obj), parent, anon_map, bindparams - ) + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): return tuple(obj) @@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal): ), ) + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + def visit_table_hint_list( self, attrname, obj, parent, anon_map, bindparams ): @@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" - def visit_clauseelement(self, parent, element, clone=_clone, **kw): + def visit_clauseelement( + self, attrname, parent, element, clone=_clone, **kw + ): return clone(element, **kw) - def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_list( + self, attrname, parent, element, clone=_clone, **kw + ): return [clone(clause, **kw) for clause in element] def visit_clauseelement_unordered_set( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return {clone(clause, **kw) for clause in element} - def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_tuples( + self, attrname, parent, element, clone=_clone, **kw + ): return [ tuple(clone(tup_elem, **kw) for tup_elem in elem) for elem in element ] def visit_string_clauseelement_dict( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return dict( (key, clone(value, **kw)) for key, value in element.items() ) - def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw): + def visit_setup_join_tuple( + self, attrname, parent, element, clone=_clone, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + clone(target, **kw) if target is not None else None, + clone(onclause, **kw) if onclause is not None else None, + clone(from_, **kw) if from_ is not None else None, + flags, + ) + for (target, onclause, from_, flags) in element + ) + + def visit_dml_ordered_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of 2-tuples return [ ( @@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal): for key, value in element ] - def visit_dml_values(self, parent, element, clone=_clone, **kw): + def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw): return { ( clone(key, **kw) if hasattr(key, "__clause_element__") else key @@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal): for key, value in element.items() } - def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): + def visit_dml_multi_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of sequences, each sequence contains a list/dict/tuple def copy(elem): @@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): continue comparison = dispatch( - left, left_child, right, right_child, **kw + left_attrname, left, left_child, right, right_child, **kw ) if comparison is COMPARE_FAILED: return False @@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return comparator.compare(obj1, obj2, **kw) def visit_has_cache_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( self.anon_map[1], [] ): return COMPARE_FAILED + def visit_has_cache_key_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + def visit_clauseelement( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): self.stack.append((left, right)) def visit_fromclause_canonical_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lcol, rcol in util.zip_longest(left, right, fillvalue=None): self.stack.append((lcol, rcol)) def visit_fromclause_derived_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): pass def visit_string_clauseelement_dict( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lstr, rstr in util.zip_longest( sorted(left), sorted(right), fillvalue=None @@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((left[lstr], right[rstr])) def visit_clauseelement_tuples( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for ltup, rtup in util.zip_longest(left, right, fillvalue=None): if ltup is None or rtup is None: @@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l, r)) def visit_clauseelement_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) @@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return len(completed) == len(seq1) == len(seq2) def visit_clauseelement_unordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return self._compare_unordered_sequences(left, right, **kw) def visit_fromclause_ordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) - def visit_string(self, left_parent, left, right_parent, right, **kw): + def visit_string( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_string_list(self, left_parent, left, right_parent, right, **kw): + def visit_string_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + def visit_anon_name( + self, attrname, left_parent, left, right_parent, right, **kw + ): return _resolve_name_for_compare( left_parent, left, self.anon_map[0], **kw ) == _resolve_name_for_compare( right_parent, right, self.anon_map[1], **kw ) - def visit_boolean(self, left_parent, left, right_parent, right, **kw): + def visit_boolean( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_operator(self, left_parent, left, right_parent, right, **kw): + def visit_operator( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left is right - def visit_type(self, left_parent, left, right_parent, right, **kw): + def visit_type( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left._compare_type_affinity(right) - def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + def visit_plain_dict( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_dialect_options( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_annotations_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left and right: return ( @@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: return left == right - def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + def visit_plain_obj( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_named_ddl_element( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left is None: if right is not None: @@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return left.name == right.name def visit_prefix_sequence( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( left, right, fillvalue=(None, None) @@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: self.stack.append((l_clause, r_clause)) + def visit_setup_join_tuple( + self, attrname, left_parent, left, right_parent, right, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + for ( + (l_target, l_onclause, l_from, l_flags), + (r_target, r_onclause, r_from, r_flags), + ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)): + if l_flags != r_flags: + return COMPARE_FAILED + self.stack.append((l_target, r_target)) + self.stack.append((l_onclause, r_onclause)) + self.stack.append((l_from, r_from)) + def visit_table_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) right_keys = sorted( @@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((ltable, rtable)) def visit_statement_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_unknown_structure( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): raise NotImplementedError() def visit_dml_ordered_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): # sequence of tuple pairs @@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return True - def visit_dml_values(self, left_parent, left, right_parent, right, **kw): + def visit_dml_values( + self, attrname, left_parent, left, right_parent, right, **kw + ): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED @@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return COMPARE_FAILED def visit_dml_multi_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lseq, rseq in util.zip_longest(left, right, fillvalue=None): if lseq is None or rseq is None: @@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None): if ( self.visit_dml_values( - left_parent, ld, right_parent, rd, **kw + attrname, left_parent, ld, right_parent, rd, **kw ) is COMPARE_FAILED ): |