summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py148
1 files changed, 80 insertions, 68 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7a525589d..d8e924cb6 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -26,8 +26,10 @@ class AssertRule(object):
pass
def no_more_statements(self):
- assert False, 'All statements are complete, but pending '\
- 'assertion rules remain'
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
class SQLMatchRule(AssertRule):
@@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule):
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters):
- self.errormessage = \
- "Testing for exact SQL %s parameters %s received %s %s" % (
- self.statement, self.params,
- stmt.statement, stmt.parameters
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
)
+ )
else:
execute_observed.statements.pop(0)
self.is_consumed = True
@@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params=None, dialect='default'):
+ def __init__(self, statement, params=None, dialect="default"):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- if self.dialect == 'default':
+ if self.dialect == "default":
return DefaultDialect()
else:
# ugh
- if self.dialect == 'postgresql':
- params = {'implicit_returning': True}
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
@@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
+ )
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- inline=context.compiled.inline,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
)
- _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters]
+ compiled.construct_params(m) for m in parameters
+ ]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
- _received_statement, _received_parameters = \
- self._received_statement(execute_observed)
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
@@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule):
for param_key in param:
# a key in param did not match current
# 'received'
- if param_key not in received or \
- received[param_key] != param[param_key]:
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
break
else:
# all keys in param matched 'received';
@@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule):
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
- 'received_statement': _received_statement,
- 'received_parameters': _received_parameters
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
}
def _all_params(self, context):
@@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule):
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement %r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.statement.replace('%', '%%'), expected_params
- )
+ "Testing for compiled statement %r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (self.statement.replace("%", "%%"), expected_params)
)
@@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- self.dialect = 'default'
+ self.dialect = "default"
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement ~%r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.orig_regex, expected_params
- )
+ "Testing for compiled statement ~%r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r" % (self.orig_regex, expected_params)
)
def _compare_sql(self, execute_observed, received_statement):
@@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r'[\n\t]', '', real_stmt)
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
- received_stmt, received_params = super(DialectSQL, self).\
- _received_statement(execute_observed)
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
@@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL):
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt)
+ "statements actually invoked" % received_stmt
+ )
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
- if paramstyle == 'pyformat':
- stmt = re.sub(
- r':([\w_]+)', r"%(\1)s", stmt)
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
- if paramstyle == 'qmark':
+ if paramstyle == "qmark":
repl = "?"
- elif paramstyle == 'format':
+ elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == 'numeric':
+ elif paramstyle == "numeric":
repl = None
- stmt = re.sub(r':([\w_]+)', repl, stmt)
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
-
def __init__(self, count):
self.count = count
self._statement_count = 0
@@ -256,12 +264,13 @@ class CountStatements(AssertRule):
def no_more_statements(self):
if self.count != self._statement_count:
- assert False, 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
class AllOf(AssertRule):
-
def __init__(self, *rules):
self.rules = set(rules)
@@ -283,7 +292,6 @@ class AllOf(AssertRule):
class EachOf(AssertRule):
-
def __init__(self, *rules):
self.rules = list(rules)
@@ -309,7 +317,6 @@ class EachOf(AssertRule):
class Or(AllOf):
-
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
@@ -331,7 +338,8 @@ class SQLExecuteObserved(object):
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"])
+ ["statement", "parameters", "context", "executemany"],
+ )
):
pass
@@ -374,21 +382,25 @@ def assert_engine(engine):
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(conn, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
- if asserter.accumulated and \
- asserter.accumulated[-1].context is context:
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
- statement, parameters, context, executemany)
+ statement, parameters, context, executemany
+ )
)
try: