summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index d183372c3..45a2496dd 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -11,6 +11,7 @@ from __future__ import annotations
import collections
import contextlib
+import itertools
import re
from .. import event
@@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL):
return received_stmt, execute_observed.context.compiled_parameters
- def _dialect_adjusted_statement(self, paramstyle):
+ def _dialect_adjusted_statement(self, dialect):
+ paramstyle = dialect.paramstyle
stmt = re.sub(r"[\n\t]", "", self.statement)
# temporarily escape out PG double colons
@@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL):
repl = "?"
elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == "numeric":
- repl = None
+ elif paramstyle.startswith("numeric"):
+ counter = itertools.count(1)
+
+ num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
+
+ def repl(m):
+ return f"{num_identifier}{next(counter)}"
+
stmt = re.sub(r":([\w_]+)", repl, stmt)
# put them back
@@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL):
return stmt
def _compare_sql(self, execute_observed, received_statement):
- paramstyle = execute_observed.context.dialect.paramstyle
- stmt = self._dialect_adjusted_statement(paramstyle)
+ stmt = self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ )
return received_statement == stmt
def _failure_message(self, execute_observed, expected_params):
- paramstyle = execute_observed.context.dialect.paramstyle
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
"%%(received_parameters)r"
% (
- self._dialect_adjusted_statement(paramstyle).replace(
- "%", "%%"
- ),
+ self._dialect_adjusted_statement(
+ execute_observed.context.dialect
+ ).replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)