summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/traversals.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-04-01 18:31:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-04-01 19:25:23 -0400
commit49b6c50016c8a038a6df7104560bb3945debe064 (patch)
tree9b5b6b9ad6a6aba5374768afd52783fd8c2170f3 /lib/sqlalchemy/sql/traversals.py
parenta9b62055bfa61c11e9fe0b2984437e2c3e32bf0e (diff)
downloadsqlalchemy-49b6c50016c8a038a6df7104560bb3945debe064.tar.gz
Repair caching / traversals for values
The test suite wasn't running the copy_internals most fixtures, enable that and try to get all cases working. Set up selectable.values to do tuple conversion within compilation step. at the same time, disable caching for selectable.values for the moment and make it equivalent to dml_multi_values. fix cache / compare / copy cases for dml_values and dml_multi_values which weren't fully tested or covered. Change-Id: I484ca6e9cb2b66c2e6a321698f2abc0838db1460
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r--lib/sqlalchemy/sql/traversals.py77
1 files changed, 39 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 9ac6cda97..032488826 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import util
from ..inspection import inspect
+from ..util import collections_abc
from ..util import HasMemoized
SKIP_TRAVERSE = util.symbol("skip_traverse")
@@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal):
]
def visit_dml_values(self, parent, element, clone=_clone, **kw):
- # sequence of dictionaries
- return [
- {
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key
- ): clone(value, **kw)
- for key, value in sub_element.items()
- }
- for sub_element in element
- ]
+ return {
+ (
+ clone(key, **kw) if hasattr(key, "__clause_element__") else key
+ ): clone(value, **kw)
+ for key, value in element.items()
+ }
def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
# sequence of sequences, each sequence contains a list/dict/tuple
@@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal):
def copy(elem):
if isinstance(elem, (list, tuple)):
return [
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key,
- clone(value, **kw)
- if hasattr(value, "__clause_element__")
- else value,
- )
- for key, value in elem
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ for value in elem
]
elif isinstance(elem, dict):
return {
@@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal):
if hasattr(value, "__clause_element__")
else value
)
- for key, value in elem
+ for key, value in elem.items()
}
else:
# TODO: use abc classes
@@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for (lk, lv), (rk, rv) in util.zip_longest(
left, right, fillvalue=(None, None)
):
- lkce = hasattr(lk, "__clause_element__")
- rkce = hasattr(rk, "__clause_element__")
- if lkce != rkce:
- return COMPARE_FAILED
- elif lkce and not self.compare_inner(lk, rk, **kw):
- return COMPARE_FAILED
- elif not lkce and lk != rk:
- return COMPARE_FAILED
- elif not self.compare_inner(lv, rv, **kw):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
return COMPARE_FAILED
+ def _compare_dml_values_or_ce(self, lv, rv, **kw):
+ lvce = hasattr(lv, "__clause_element__")
+ rvce = hasattr(rv, "__clause_element__")
+ if lvce != rvce:
+ return False
+ elif lvce and not self.compare_inner(lv, rv, **kw):
+ return False
+ elif not lvce and lv != rv:
+ return False
+ elif not self.compare_inner(lv, rv, **kw):
+ return False
+
+ return True
+
def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
- for lk in left:
- lv = left[lk]
+ if isinstance(left, collections_abc.Sequence):
+ for lv, rv in zip(left, right):
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ else:
+ for lk in left:
+ lv = left[lk]
- if lk not in right:
- return COMPARE_FAILED
- rv = right[lk]
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
- if not self.compare_inner(lv, rv, **kw):
- return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
def visit_dml_multi_values(
self, left_parent, left, right_parent, right, **kw