summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/traversals.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r--lib/sqlalchemy/sql/traversals.py216
1 files changed, 173 insertions, 43 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 8c63fcba1..a308feb7c 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
-class HasCacheKey(HasMemoized):
+class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
-
__slots__ = ()
def _gen_cache_key(self, anon_map, bindparams):
@@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized):
return result
- @HasMemoized.memoized_instancemethod
def _generate_cache_key(self):
"""return a cache key.
@@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized):
else:
return CacheKey(key, bindparams)
+ @classmethod
+ def _generate_cache_key_for_object(cls, obj):
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = obj._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ @HasMemoized.memoized_instancemethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key(self)
+
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __hash__(self):
@@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __eq__(self, other):
return self.key == other.key
+ def _whats_different(self, other):
+
+ k1 = self.key
+ k2 = other.key
+
+ stack = []
+ pickup_index = 0
+ while True:
+ s1, s2 = k1, k2
+ for idx in stack:
+ s1 = s1[idx]
+ s2 = s2[idx]
+
+ for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
+ if idx < pickup_index:
+ continue
+ if e1 != e2:
+ if isinstance(e1, tuple) and isinstance(e2, tuple):
+ stack.append(idx)
+ break
+ else:
+ yield "key%s[%d]: %s != %s" % (
+ "".join("[%d]" % id_ for id_ in stack),
+ idx,
+ e1,
+ e2,
+ )
+ else:
+ pickup_index = stack.pop(-1)
+ break
+
+ def _diff(self, other):
+ return ", ".join(self._whats_different(other))
+
def __str__(self):
stack = [self.key]
@@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
- return self.visit_has_cache_key(
- attrname, inspect(obj), parent, anon_map, bindparams
- )
+ return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
return tuple(obj)
@@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal):
),
)
+ def visit_setup_join_tuple(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ target._gen_cache_key(anon_map, bindparams),
+ onclause._gen_cache_key(anon_map, bindparams)
+ if onclause is not None
+ else None,
+ from_._gen_cache_key(anon_map, bindparams)
+ if from_ is not None
+ else None,
+ tuple([(key, flags[key]) for key in sorted(flags)]),
+ )
+ for (target, onclause, from_, flags) in obj
+ )
+
def visit_table_hint_list(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
- def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return clone(element, **kw)
- def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_list(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [clone(clause, **kw) for clause in element]
def visit_clauseelement_unordered_set(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return {clone(clause, **kw) for clause in element}
- def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_tuples(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [
tuple(clone(tup_elem, **kw) for tup_elem in elem)
for elem in element
]
def visit_string_clauseelement_dict(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return dict(
(key, clone(value, **kw)) for key, value in element.items()
)
- def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw):
+ def visit_setup_join_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ clone(target, **kw) if target is not None else None,
+ clone(onclause, **kw) if onclause is not None else None,
+ clone(from_, **kw) if from_ is not None else None,
+ flags,
+ )
+ for (target, onclause, from_, flags) in element
+ )
+
+ def visit_dml_ordered_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of 2-tuples
return [
(
@@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal):
for key, value in element
]
- def visit_dml_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
return {
(
clone(key, **kw) if hasattr(key, "__clause_element__") else key
@@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal):
for key, value in element.items()
}
- def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_multi_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of sequences, each sequence contains a list/dict/tuple
def copy(elem):
@@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
continue
comparison = dispatch(
- left, left_child, right, right_child, **kw
+ left_attrname, left, left_child, right, right_child, **kw
)
if comparison is COMPARE_FAILED:
return False
@@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return comparator.compare(obj1, obj2, **kw)
def visit_has_cache_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
self.anon_map[1], []
):
return COMPARE_FAILED
+ def visit_has_cache_key_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
def visit_clauseelement(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
self.stack.append((left, right))
def visit_fromclause_canonical_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
self.stack.append((lcol, rcol))
def visit_fromclause_derived_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
pass
def visit_string_clauseelement_dict(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lstr, rstr in util.zip_longest(
sorted(left), sorted(right), fillvalue=None
@@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((left[lstr], right[rstr]))
def visit_clauseelement_tuples(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
if ltup is None or rtup is None:
@@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l, r))
def visit_clauseelement_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
@@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return len(completed) == len(seq1) == len(seq2)
def visit_clauseelement_unordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return self._compare_unordered_sequences(left, right, **kw)
def visit_fromclause_ordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
- def visit_string(self, left_parent, left, right_parent, right, **kw):
+ def visit_string(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_string_list(self, left_parent, left, right_parent, right, **kw):
+ def visit_string_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ def visit_anon_name(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
) == _resolve_name_for_compare(
right_parent, right, self.anon_map[1], **kw
)
- def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ def visit_boolean(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ def visit_operator(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left is right
- def visit_type(self, left_parent, left, right_parent, right, **kw):
+ def visit_type(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left._compare_type_affinity(right)
- def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_dialect_options(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_annotations_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left and right:
return (
@@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
return left == right
- def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_obj(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_named_ddl_element(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left is None:
if right is not None:
@@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return left.name == right.name
def visit_prefix_sequence(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
left, right, fillvalue=(None, None)
@@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
self.stack.append((l_clause, r_clause))
+ def visit_setup_join_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ for (
+ (l_target, l_onclause, l_from, l_flags),
+ (r_target, r_onclause, r_from, r_flags),
+ ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
+ if l_flags != r_flags:
+ return COMPARE_FAILED
+ self.stack.append((l_target, r_target))
+ self.stack.append((l_onclause, r_onclause))
+ self.stack.append((l_from, r_from))
+
def visit_table_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
right_keys = sorted(
@@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((ltable, rtable))
def visit_statement_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_unknown_structure(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
raise NotImplementedError()
def visit_dml_ordered_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
# sequence of tuple pairs
@@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return True
- def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
+ def visit_dml_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
@@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return COMPARE_FAILED
def visit_dml_multi_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
if lseq is None or rseq is None:
@@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
if (
self.visit_dml_values(
- left_parent, ld, right_parent, rd, **kw
+ attrname, left_parent, ld, right_parent, rd, **kw
)
is COMPARE_FAILED
):