diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index d183372c3..45a2496dd 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -11,6 +11,7 @@ from __future__ import annotations import collections import contextlib +import itertools import re from .. import event @@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL): return received_stmt, execute_observed.context.compiled_parameters - def _dialect_adjusted_statement(self, paramstyle): + def _dialect_adjusted_statement(self, dialect): + paramstyle = dialect.paramstyle stmt = re.sub(r"[\n\t]", "", self.statement) # temporarily escape out PG double colons @@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL): repl = "?" elif paramstyle == "format": repl = r"%s" - elif paramstyle == "numeric": - repl = None + elif paramstyle.startswith("numeric"): + counter = itertools.count(1) + + num_identifier = "$" if paramstyle == "numeric_dollar" else ":" + + def repl(m): + return f"{num_identifier}{next(counter)}" + stmt = re.sub(r":([\w_]+)", repl, stmt) # put them back @@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL): return stmt def _compare_sql(self, execute_observed, received_statement): - paramstyle = execute_observed.context.dialect.paramstyle - stmt = self._dialect_adjusted_statement(paramstyle) + stmt = self._dialect_adjusted_statement( + execute_observed.context.dialect + ) return received_statement == stmt def _failure_message(self, execute_observed, expected_params): - paramstyle = execute_observed.context.dialect.paramstyle return ( "Testing for compiled statement\n%r partial params %s, " "received\n%%(received_statement)r with params " "%%(received_parameters)r" % ( - self._dialect_adjusted_statement(paramstyle).replace( - "%", "%%" - ), + self._dialect_adjusted_statement( + execute_observed.context.dialect + ).replace("%", "%%"), repr(expected_params).replace("%", "%%"), ) ) |