diff options
author | Federico Caselli <cfederico87@gmail.com> | 2022-11-19 20:39:10 +0100 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2022-12-01 23:50:30 +0100 |
commit | 0f2baae6bf72353f785bad394684f2d6fa53e0ef (patch) | |
tree | 4d7c2cd6e8a73106aa4f95105968cf6e3fded813 /lib/sqlalchemy/testing/assertions.py | |
parent | c440c920aecd6593974e5a0d37cdb9069e5d3e57 (diff) | |
download | sqlalchemy-0f2baae6bf72353f785bad394684f2d6fa53e0ef.tar.gz |
Fix positional compiling bugs
Fixed a series of issues regarding positionally rendered bound parameters,
such as those used for SQLite, asyncpg, MySQL and others. Some compiled
forms would not maintain the order of parameters correctly, such as the
PostgreSQL ``regexp_replace()`` function as well as within the "nesting"
feature of the :class:`.CTE` construct first introduced in :ticket:`4123`.
Fixes: #8827
Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 64 |
1 files changed, 61 insertions, 3 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 321c05b44..790a72ec8 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections import defaultdict import contextlib +from copy import copy from itertools import filterfalse import re import sys @@ -493,6 +495,7 @@ class AssertsCompiledSQL: render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -506,8 +509,11 @@ class AssertsCompiledSQL: if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -632,7 +638,7 @@ class AssertsCompiledSQL: if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -652,6 +658,58 @@ class AssertsCompiledSQL: }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables: |