summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2022-11-19 20:39:10 +0100
committerFederico Caselli <cfederico87@gmail.com>2022-12-01 23:50:30 +0100
commit0f2baae6bf72353f785bad394684f2d6fa53e0ef (patch)
tree4d7c2cd6e8a73106aa4f95105968cf6e3fded813 /lib/sqlalchemy/testing/assertions.py
parentc440c920aecd6593974e5a0d37cdb9069e5d3e57 (diff)
downloadsqlalchemy-0f2baae6bf72353f785bad394684f2d6fa53e0ef.tar.gz
Fix positional compiling bugs
Fixed a series of issues regarding positionally rendered bound parameters, such as those used for SQLite, asyncpg, MySQL and others. Some compiled forms would not maintain the order of parameters correctly, such as the PostgreSQL ``regexp_replace()`` function as well as within the "nesting" feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. Fixes: #8827 Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--lib/sqlalchemy/testing/assertions.py64
1 files changed, 61 insertions, 3 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 321c05b44..790a72ec8 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -9,7 +9,9 @@
from __future__ import annotations
+from collections import defaultdict
import contextlib
+from copy import copy
from itertools import filterfalse
import re
import sys
@@ -493,6 +495,7 @@ class AssertsCompiledSQL:
render_schema_translate=False,
default_schema_name=None,
from_linting=False,
+ check_param_order=True,
):
if use_default_dialect:
dialect = default.DefaultDialect()
@@ -506,8 +509,11 @@ class AssertsCompiledSQL:
if dialect is None:
dialect = config.db.dialect
- elif dialect == "default":
- dialect = default.DefaultDialect()
+ elif dialect == "default" or dialect == "default_qmark":
+ if dialect == "default":
+ dialect = default.DefaultDialect()
+ else:
+ dialect = default.DefaultDialect("qmark")
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif dialect == "default_enhanced":
@@ -632,7 +638,7 @@ class AssertsCompiledSQL:
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
- p = c.construct_params(params)
+ p = c.construct_params(params, escape_names=False)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
@@ -652,6 +658,58 @@ class AssertsCompiledSQL:
},
check_post_param,
)
+ if check_param_order and getattr(c, "params", None):
+
+ def get_dialect(paramstyle, positional):
+ cp = copy(dialect)
+ cp.paramstyle = paramstyle
+ cp.positional = positional
+ return cp
+
+ pyformat_dialect = get_dialect("pyformat", False)
+ pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+ stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
+
+ qmark_dialect = get_dialect("qmark", True)
+ qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+ values = list(qmark_c.positiontup)
+ escaped = qmark_c.escaped_bind_names
+
+ for post_param in (
+ qmark_c.post_compile_params | qmark_c.literal_execute_params
+ ):
+ name = qmark_c.bind_names[post_param]
+ if name in values:
+ values = [v for v in values if v != name]
+ positions = []
+ pos_by_value = defaultdict(list)
+ for v in values:
+ try:
+ if v in pos_by_value:
+ start = pos_by_value[v][-1]
+ else:
+ start = 0
+ esc = escaped.get(v, v)
+ pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+ positions.append(pos)
+ pos_by_value[v].append(pos)
+ except ValueError:
+ msg = "Expected to find bindparam %r in %r" % (v, stmt)
+ assert False, msg
+
+ ordered = all(
+ positions[i - 1] < positions[i]
+ for i in range(1, len(positions))
+ )
+
+ expected = [v for _, v in sorted(zip(positions, values))]
+
+ msg = (
+ "Order of parameters %s does not match the order "
+ "in the statement %s. Statement %r" % (values, expected, stmt)
+ )
+
+ is_true(ordered, msg)
class ComparesTables: