summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/clause_compare.py
blob: 50b1df99ed6134cffc031753b2fa9d4e3116d332 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
from collections import deque

from . import operators
from .. import util


# Sentinel that a ``compare_<visit_name>`` method may return to mean "these
# two nodes match; do not push their children onto the comparison stack".
# Methods return it when they have already compared the children themselves
# (e.g. order-insensitive clause lists, commutative binary expressions).
SKIP_TRAVERSE = util.symbol("skip_traverse")


def compare(obj1, obj2, **kw):
    """Return True if clause elements ``obj1`` and ``obj2`` are equivalent.

    When ``use_proxies`` is passed and true, the comparison uses
    :class:`ColIdentityComparatorStrategy`, which matches columns by shared
    lineage and optional ``equivalents``. Otherwise it is a purely
    structural comparison using :class:`StructureComparatorStrategy`.
    Every keyword argument is forwarded to the chosen strategy's
    ``compare()``.
    """
    strategy_cls = (
        ColIdentityComparatorStrategy
        if kw.get("use_proxies", False)
        else StructureComparatorStrategy
    )
    return strategy_cls().compare(obj1, obj2, **kw)


class StructureComparatorStrategy(object):
    """Compare two clause-element trees for structural equivalence.

    Corresponding nodes must have the same ``__visit_name__``. If a
    ``compare_<visit_name>`` method exists, it checks the node-level
    attributes. Children are then compared pairwise unless that method
    returned ``SKIP_TRAVERSE``.

    ``compare_stack`` holds the ``(left, right)`` pairs still to be
    examined and is consumed breadth-first. ``cache`` holds pairs that have
    already been visited, so each pair is examined only once. This also
    stops the column -> table -> columns round trip introduced by
    ``compare_column`` / ``compare_table`` from looping.
    """

    __slots__ = "compare_stack", "cache"

    def __init__(self):
        self.compare_stack = deque()
        self.cache = set()

    def compare(self, obj1, obj2, **kw):
        """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

        ``kw`` is passed unchanged to every ``compare_<visit_name>`` method
        that is invoked during the traversal.
        """
        stack = self.compare_stack
        cache = self.cache

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                # pair already visited in this comparison
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__

            # we're not exactly looking for identical types, because
            # there are things like Column and AnnotatedColumn.  So the
            # visit_name has to at least match up
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                comparison = meth(left, right, **kw)
                if comparison is False:
                    return False
                elif comparison is SKIP_TRAVERSE:
                    continue

            for c1, c2 in util.zip_longest(
                left.get_children(column_collections=False),
                right.get_children(column_collections=False),
                fillvalue=None,
            ):
                if c1 is None or c2 is None:
                    # collections are different sizes, comparison fails
                    return False
                stack.append((c1, c2))

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Run a nested comparison of ``obj1`` and ``obj2`` on a private stack.

        The enclosing traversal's stack is restored afterwards, even if the
        nested comparison raises. The visited-pair ``cache`` is shared with
        the enclosing comparison, so pairs seen there are skipped here too.
        """
        # NOTE(review): pairs that a nested comparison adds to the shared
        # cache are not removed when it returns False. A later visit of the
        # same pair in this comparator is therefore skipped as if it had
        # matched.
        saved_stack = self.compare_stack
        self.compare_stack = deque()
        try:
            return self.compare(obj1, obj2, **kw)
        finally:
            self.compare_stack = saved_stack

    def _compare_isolated(self, obj1, obj2, **kw):
        """Compare on a fresh comparator of this class (own stack and cache).

        Because the comparator is new, a failed trial comparison leaves no
        entries in ``self.cache``. The strategy class is also preserved even
        when ``use_proxies`` is absent from ``kw``.
        """
        return self.__class__().compare(obj1, obj2, **kw)

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Return True if ``seq1`` and ``seq2`` hold equivalent elements in
        any order.

        Each element of ``seq2`` may be matched by at most one element of
        ``seq1``. A repeated element therefore has to appear the same
        number of times on both sides. ``None`` only matches ``None``.
        """
        if seq1 is None:
            return seq2 is None
        if seq2 is None or len(seq1) != len(seq2):
            return False

        # a list (not a set) so repeated elements are each matched once
        unmatched = list(seq2)
        for clause in seq1:
            for idx, other_clause in enumerate(unmatched):
                if self.compare_inner(clause, other_clause, **kw):
                    # consume the match so no other clause can reuse it;
                    # safe because we stop iterating immediately
                    del unmatched[idx]
                    break
            else:
                # nothing left in seq2 is equivalent to this clause
                return False
        return True

    def compare_bindparam(self, left, right, **kw):
        """Compare bind parameters on type affinity, value, callable and
        original key."""
        # note the ".key" is often generated from id(self) so can't
        # be compared, as far as determining structure.
        return (
            left.type._compare_type_affinity(right.type)
            and left.value == right.value
            and left.callable == right.callable
            and left._orig_key == right._orig_key
        )

    def compare_clauselist(self, left, right, **kw):
        """Compare clause lists.

        The operators must be the same object. For associative operators
        the clauses are compared here, ignoring order, and traversal of the
        children is skipped. Otherwise the children are compared in order
        by the main traversal.
        """
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return SKIP_TRAVERSE
                else:
                    return False
            else:
                # normal ordered traversal
                return True
        else:
            return False

    def compare_unary(self, left, right, **kw):
        """Compare unary expressions on their operator and modifier.

        If an operator- or modifier-specific hook exists
        (``visit_<name>_unary_operator`` / ``visit_<name>_unary_modifier``),
        it is consulted first, and any non-True result from it is returned
        as-is.
        """
        if left.operator:
            disp = self._get_operator_dispatch(
                left.operator, "unary", "operator"
            )
            if disp is not None:
                result = disp(left, right, left.operator, **kw)
                if result is not True:
                    return result
        elif left.modifier:
            disp = self._get_operator_dispatch(
                left.modifier, "unary", "modifier"
            )
            if disp is not None:
                # the hook was looked up by the modifier, so pass the
                # modifier to it (left.operator is falsy in this branch)
                result = disp(left, right, left.modifier, **kw)
                if result is not True:
                    return result
        return (
            left.operator == right.operator and left.modifier == right.modifier
        )

    def compare_binary(self, left, right, **kw):
        """Compare binary expressions.

        If a ``visit_<operator>_binary`` hook exists it is consulted first.
        The operators must then be equal. For commutative operators the
        operands may also appear swapped; they are compared here and
        traversal of the children is skipped.
        """
        disp = self._get_operator_dispatch(left.operator, "binary", None)
        if disp:
            result = disp(left, right, left.operator, **kw)
            if result is not True:
                return result

        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                cmp_ = self._compare_isolated
                if (
                    cmp_(left.left, right.left, **kw)
                    and cmp_(left.right, right.right, **kw)
                ) or (
                    cmp_(left.left, right.right, **kw)
                    and cmp_(left.right, right.left, **kw)
                ):
                    return SKIP_TRAVERSE
                else:
                    return False
            else:
                # order matters; children are compared by the main traversal
                return True
        else:
            return False

    def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
        """Return the ``visit_<op>_<q1>[_<q2>]`` hook for ``operator_``,
        or None if this strategy defines no such method."""
        # used by compare_binary, compare_unary
        attrname = "visit_%s_%s%s" % (
            operator_.__name__,
            qualifier1,
            "_" + qualifier2 if qualifier2 else "",
        )
        return getattr(self, attrname, None)

    def visit_function_as_comparison_op_binary(
        self, left, right, operator, **kw
    ):
        """Hook for the function_as_comparison_op operator: the argument
        indexes used on each side must match."""
        return (
            left.left_index == right.left_index
            and left.right_index == right.right_index
        )

    def compare_function(self, left, right, **kw):
        """SQL functions must have the same name."""
        return left.name == right.name

    def compare_column(self, left, right, **kw):
        """Compare columns on key, name, type affinity and literal flag.

        The pair of parent tables is pushed to the front of the stack, so it
        is examined next.
        """
        if left.table is not None:
            self.compare_stack.appendleft((left.table, right.table))
        return (
            left.key == right.key
            and left.name == right.name
            and (
                left.type._compare_type_affinity(right.type)
                if left.type is not None
                else right.type is None
            )
            and left.is_literal == right.is_literal
        )

    def compare_collation(self, left, right, **kw):
        """Collation clauses must name the same collation."""
        return left.collation == right.collation

    def compare_type_coerce(self, left, right, **kw):
        """type_coerce() targets must have compatible type affinity."""
        return left.type._compare_type_affinity(right.type)

    @util.dependencies("sqlalchemy.sql.elements")
    def compare_alias(self, elements, left, right, **kw):
        """Aliases must share a name; an anonymous name matches any other
        anonymous name."""
        return (
            left.name == right.name
            if not isinstance(left.name, elements._anonymous_label)
            else isinstance(right.name, elements._anonymous_label)
        )

    def compare_extract(self, left, right, **kw):
        """EXTRACT expressions must extract the same field."""
        return left.field == right.field

    def compare_textual_label_reference(self, left, right, **kw):
        """Textual label references must refer to the same element."""
        return left.element == right.element

    def compare_slice(self, left, right, **kw):
        """Slices must have equal start, stop and step."""
        return (
            left.start == right.start
            and left.stop == right.stop
            and left.step == right.step
        )

    def compare_over(self, left, right, **kw):
        """OVER clauses must have the same range and rows settings."""
        return left.range_ == right.range_ and left.rows == right.rows

    @util.dependencies("sqlalchemy.sql.elements")
    def compare_label(self, elements, left, right, **kw):
        """Labels must have compatible type affinity and the same name; an
        anonymous name matches any other anonymous name."""
        return left._type._compare_type_affinity(right._type) and (
            left.name == right.name
            if not isinstance(left.name, elements._anonymous_label)
            else isinstance(right.name, elements._anonymous_label)
        )

    def compare_typeclause(self, left, right, **kw):
        """Type clauses must have compatible type affinity."""
        return left.type._compare_type_affinity(right.type)

    def compare_join(self, left, right, **kw):
        """Joins must agree on the isouter and full flags."""
        return left.isouter == right.isouter and left.full == right.full

    def compare_table(self, left, right, **kw):
        """Tables must share a name; their columns are then compared
        pairwise.

        The column pairs go to the front of the stack. A column-count
        mismatch yields a ``(column, None)`` pair, which fails the
        comparison.
        """
        if left.name != right.name:
            return False

        self.compare_stack.extendleft(
            util.zip_longest(left.columns, right.columns)
        )
        # explicit True (previously an implicit None, which was treated
        # the same way): continue with the normal child traversal
        return True

    def compare_compound_select(self, left, right, **kw):
        """Compare compound selects (UNION etc.).

        The member selects are compared in any order. The keyword, the FOR
        UPDATE argument and the ORDER BY / GROUP BY clauses must also
        match. Traversal of the children is then skipped.
        """
        if not self._compare_unordered_sequences(
            left.selects, right.selects, **kw
        ):
            return False

        if left.keyword != right.keyword:
            return False

        if left._for_update_arg != right._for_update_arg:
            return False

        if not self.compare_inner(
            left._order_by_clause, right._order_by_clause, **kw
        ):
            return False

        if not self.compare_inner(
            left._group_by_clause, right._group_by_clause, **kw
        ):
            return False

        return SKIP_TRAVERSE

    def compare_select(self, left, right, **kw):
        """Compare select statements.

        The correlation collections and the explicit FROM list are compared
        ignoring order, and the FOR UPDATE arguments must be equal. The
        remaining children are compared by the normal traversal.
        """
        if not self._compare_unordered_sequences(
            left._correlate, right._correlate, **kw
        ):
            return False
        if not self._compare_unordered_sequences(
            left._correlate_except, right._correlate_except, **kw
        ):
            return False

        if not self._compare_unordered_sequences(
            left._from_obj, right._from_obj, **kw
        ):
            return False

        if left._for_update_arg != right._for_update_arg:
            return False

        return True

    def compare_textual_select(self, left, right, **kw):
        """Textual selects must agree on ``positional``; their column
        arguments are pushed to the front of the stack and compared
        pairwise."""
        self.compare_stack.extendleft(
            util.zip_longest(left.column_args, right.column_args)
        )
        return left.positional == right.positional


class ColIdentityComparatorStrategy(StructureComparatorStrategy):
    """Comparison strategy selected by ``compare(..., use_proxies=True)``.

    Column-like elements (columns and labels) match when ``left`` shares
    lineage with, or hashes equal to, ``right`` or one of its
    ``equivalents``. Tables match only on identity.
    """

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        :param use_proxies: when true, ``left`` matches any candidate for
          which ``left.shares_lineage(candidate)`` is true.
        :param equivalents: optional mapping keyed by column. When
          ``right`` is a key, every member of ``equivalents[right]``
          (presumably a set of columns, since ``.union()`` is called on it)
          is also a candidate.
        :return: True if ``left`` matches ``right`` or one of its
          equivalents, otherwise False.
        """
        candidates = (right,)
        if equivalents and right in equivalents:
            candidates = equivalents[right].union(candidates)

        for oth in candidates:
            if use_proxies and left.shares_lineage(oth):
                return True
            # test each candidate, not only ``right``, so that equivalents
            # are honoured even without use_proxies
            elif hash(left) == hash(oth):
                return True
        return False

    def compare_column(self, left, right, **kw):
        """Columns are compared by lineage / equivalence, not structure."""
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        """Labels are compared by lineage / equivalence, not structure."""
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        """Tables match only if they are the same object."""
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return left is right