diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 64 |
1 files changed, 61 insertions, 3 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 321c05b44..790a72ec8 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections import defaultdict import contextlib +from copy import copy from itertools import filterfalse import re import sys @@ -493,6 +495,7 @@ class AssertsCompiledSQL: render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -506,8 +509,11 @@ class AssertsCompiledSQL: if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -632,7 +638,7 @@ class AssertsCompiledSQL: if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -652,6 +658,58 @@ class AssertsCompiledSQL: }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables: |