summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--lib/sqlalchemy/testing/assertions.py64
1 files changed, 61 insertions, 3 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 321c05b44..790a72ec8 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -9,7 +9,9 @@
from __future__ import annotations
+from collections import defaultdict
import contextlib
+from copy import copy
from itertools import filterfalse
import re
import sys
@@ -493,6 +495,7 @@ class AssertsCompiledSQL:
render_schema_translate=False,
default_schema_name=None,
from_linting=False,
+ check_param_order=True,
):
if use_default_dialect:
dialect = default.DefaultDialect()
@@ -506,8 +509,11 @@ class AssertsCompiledSQL:
if dialect is None:
dialect = config.db.dialect
- elif dialect == "default":
- dialect = default.DefaultDialect()
+ elif dialect == "default" or dialect == "default_qmark":
+ if dialect == "default":
+ dialect = default.DefaultDialect()
+ else:
+ dialect = default.DefaultDialect("qmark")
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif dialect == "default_enhanced":
@@ -632,7 +638,7 @@ class AssertsCompiledSQL:
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
- p = c.construct_params(params)
+ p = c.construct_params(params, escape_names=False)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
@@ -652,6 +658,58 @@ class AssertsCompiledSQL:
},
check_post_param,
)
+ if check_param_order and getattr(c, "params", None):
+
+ def get_dialect(paramstyle, positional):
+ cp = copy(dialect)
+ cp.paramstyle = paramstyle
+ cp.positional = positional
+ return cp
+
+ pyformat_dialect = get_dialect("pyformat", False)
+ pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+ stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
+
+ qmark_dialect = get_dialect("qmark", True)
+ qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+ values = list(qmark_c.positiontup)
+ escaped = qmark_c.escaped_bind_names
+
+ for post_param in (
+ qmark_c.post_compile_params | qmark_c.literal_execute_params
+ ):
+ name = qmark_c.bind_names[post_param]
+ if name in values:
+ values = [v for v in values if v != name]
+ positions = []
+ pos_by_value = defaultdict(list)
+ for v in values:
+ try:
+ if v in pos_by_value:
+ start = pos_by_value[v][-1]
+ else:
+ start = 0
+ esc = escaped.get(v, v)
+ pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+ positions.append(pos)
+ pos_by_value[v].append(pos)
+ except ValueError:
+ msg = "Expected to find bindparam %r in %r" % (v, stmt)
+ assert False, msg
+
+ ordered = all(
+ positions[i - 1] < positions[i]
+ for i in range(1, len(positions))
+ )
+
+ expected = [v for _, v in sorted(zip(positions, values))]
+
+ msg = (
+ "Order of parameters %s does not match the order "
+ "in the statement %s. Statement %r" % (values, expected, stmt)
+ )
+
+ is_true(ordered, msg)
class ComparesTables: