diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 92 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 3 |
3 files changed, 75 insertions, 33 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bf7c27a89..66d1f3cb0 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -405,13 +405,16 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + def assert_sql_execution(self, db, callable_, *rules): - assertsql.asserter.add_rules(rules) - try: + with self.sql_execution_asserter(db) as asserter: callable_() - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() + asserter.assert_(*rules) def assert_sql(self, db, callable_, list_, with_sequences=None): if (with_sequences is not None and diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index bcc999fe3..2ac0605a2 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -8,6 +8,9 @@ from ..engine.default import DefaultDialect from .. import util import re +import collections +import contextlib +from .. import event class AssertRule(object): @@ -321,39 +324,78 @@ def _process_assertion_statement(query, context): return query -class SQLAssert(object): +class SQLExecuteObserved( + collections.namedtuple( + "SQLExecuteObserved", ["clauseelement", "multiparams", "params"]) +): + def process(self, rules): + if rules is not None: + if not rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = rules[0] + rule.process_execute( + self.clauseelement, *self.multiparams, **self.params) + if rule.is_consumed(): + rules.pop(0) - rules = None - def add_rules(self, rules): - self.rules = list(rules) +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"]) +): + def process(self, rules): + if rules: + rule = rules[0] + rule.process_cursor_execute( + self.statement, self.parameters, + self.context, self.executemany) - def statement_complete(self): - for rule in self.rules: + +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] + + def _close(self): + # safety feature in case event.remove + # goes haywire + self._final = self.accumulated + del self.accumulated + + def assert_(self, *rules): + rules = list(rules) + for observed in self._final: + observed.process(rules) + + for rule in rules: if not rule.consume_final(): assert False, \ 'All statements are complete, but pending '\ 'assertion rules remain' - def clear_rules(self): - del self.rules - def execute(self, conn, clauseelement, multiparams, params, result): - if self.rules is not None: - if not self.rules: - assert False, \ - 'All rules have been exhausted, but further '\ - 'statements remain' - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() - def cursor_execute(self, conn, cursor, statement, parameters, - context, executemany): - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, - executemany) + @event.listens_for(engine, "after_execute") + def execute(conn, clauseelement, multiparams, params, result): + asserter.accumulated.append( + SQLExecuteObserved( + clauseelement, multiparams, params)) -asserter = SQLAssert() + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute(conn, cursor, statement, parameters, + context, executemany): + asserter.accumulated.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany)) + + try: + yield asserter + finally: + asserter._close() + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "after_execute", execute) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 0f6f59401..7d73e7423 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -204,7 +204,6 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from .assertsql import asserter if not options: use_reaper = True @@ -219,8 +218,6 @@ def testing_engine(url=None, options=None): if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 - event.listen(engine, 'after_execute', asserter.execute) - event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) if use_reaper: event.listen(engine.pool, 'connect', testing_reaper.connect) event.listen(engine.pool, 'checkout', testing_reaper.checkout) |