summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py31
1 files changed, 24 insertions, 7 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index ca0bc6726..9816c99b7 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -186,7 +186,9 @@ class CompiledSQL(SQLMatchRule):
self.is_consumed = True
self.errormessage = None
else:
- self.errormessage = self._failure_message(params) % {
+ self.errormessage = self._failure_message(
+ execute_observed, params
+ ) % {
"received_statement": _received_statement,
"received_parameters": _received_parameters,
}
@@ -203,7 +205,7 @@ class CompiledSQL(SQLMatchRule):
else:
return None
- def _failure_message(self, expected_params):
+ def _failure_message(self, execute_observed, expected_params):
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
@@ -223,7 +225,7 @@ class RegexSQL(CompiledSQL):
self.params = params
self.dialect = dialect
- def _failure_message(self, expected_params):
+ def _failure_message(self, execute_observed, expected_params):
return (
"Testing for compiled statement ~%r partial params %s, "
"received %%(received_statement)r with params "
@@ -263,11 +265,8 @@ class DialectSQL(CompiledSQL):
return received_stmt, execute_observed.context.compiled_parameters
- def _compare_sql(self, execute_observed, received_statement):
+ def _dialect_adjusted_statement(self, paramstyle):
stmt = re.sub(r"[\n\t]", "", self.statement)
- # convert our comparison statement to have the
- # paramstyle of the received
- paramstyle = execute_observed.context.dialect.paramstyle
if paramstyle == "pyformat":
stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
@@ -280,9 +279,27 @@ class DialectSQL(CompiledSQL):
elif paramstyle == "numeric":
repl = None
stmt = re.sub(r":([\w_]+)", repl, stmt)
+ return stmt
+ def _compare_sql(self, execute_observed, received_statement):
+ paramstyle = execute_observed.context.dialect.paramstyle
+ stmt = self._dialect_adjusted_statement(paramstyle)
return received_statement == stmt
+ def _failure_message(self, execute_observed, expected_params):
+ paramstyle = execute_observed.context.dialect.paramstyle
+ return (
+ "Testing for compiled statement\n%r partial params %s, "
+ "received\n%%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (
+ self._dialect_adjusted_statement(paramstyle).replace(
+ "%", "%%"
+ ),
+ repr(expected_params).replace("%", "%%"),
+ )
+ )
+
class CountStatements(AssertRule):
def __init__(self, count):