diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-12-10 02:16:52 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-12-10 02:16:52 +0000 |
commit | a2f90fd0038ab6586022762467ed95a84e62e224 (patch) | |
tree | 565cb80d9c3c5885f992c712890f26991c6cebc8 /test/testlib/testing.py | |
parent | 41f5f3465a444f36473593e179168841829fc1a5 (diff) | |
download | sqlalchemy-a2f90fd0038ab6586022762467ed95a84e62e224.tar.gz |
- reworked the "SQL assertion" code to something more flexible and based off of ConnectionProxy. upcoming changes to dependency.py
will make use of the enhanced flexibility.
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r-- | test/testlib/testing.py | 154 |
1 files changed, 25 insertions, 129 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py index f6a16a4f8..5f5d323c7 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -13,6 +13,7 @@ from cStringIO import StringIO import testlib.config as config from testlib.compat import _function_named +from testlib import assertsql # Delayed imports MetaData = None @@ -459,110 +460,6 @@ def fixture(table, columns, *rows): for column_values in rows]) table.append_ddl_listener('after-create', onload) -class TestData(object): - """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" - - def __init__(self): - self.set_assert_list(None, None) - self.sql_count = 0 - self.buffer = None - - def set_assert_list(self, unittest, list): - self.unittest = unittest - self.assert_list = list - if list is not None: - self.assert_list.reverse() - -testdata = TestData() - - -class ExecutionContextWrapper(object): - """instruments the ExecutionContext created by the Engine so that SQL expressions - can be tracked.""" - - def __init__(self, ctx): - self.__dict__['ctx'] = ctx - def __getattr__(self, key): - return getattr(self.ctx, key) - def __setattr__(self, key, value): - setattr(self.ctx, key, value) - - trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE) - def post_exec(self): - ctx = self.ctx - statement = unicode(ctx.compiled) - statement = re.sub(r'\n', '', ctx.statement) - if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): - statement = statement[:-25] - if testdata.buffer is not None: - testdata.buffer.write(statement + "\n") - - if testdata.assert_list is not None: - assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement - item = testdata.assert_list[-1] - if not isinstance(item, dict): - item = testdata.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if '_converted' not in item: - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - testdata.assert_list.pop() - item = (statement, entry) - except KeyError: - assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement) - - (query, params) = item - if callable(params): - params = params(ctx) - if params is not None and not isinstance(params, list): - params = [params] - - parameters = ctx.compiled_parameters - - query = self.convert_statement(query) - equivalent = ( (statement == query) - or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) - ) \ - and \ - ( (params is None) or (params == parameters) - or params == [dict([(k.strip('_'), v) - for (k, v) in p.items()]) - for p in parameters] - ) - testdata.unittest.assert_(equivalent, - "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - testdata.sql_count += 1 - self.ctx.post_exec() - - def convert_statement(self, query): - paramstyle = self.ctx.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle =='pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle=='qmark': - repl = "?" - elif paramstyle=='format': - repl = r"%s" - elif paramstyle=='numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) - return query - - def _import_by_name(name): submodule = name.split('.')[-1] return __import__(name, globals(), locals(), [submodule]) @@ -827,36 +724,35 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True - def assert_sql(self, db, callable_, list, with_sequences=None): - global testdata - testdata = TestData() - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): - testdata.set_assert_list(self, with_sequences) - else: - testdata.set_assert_list(self, list) + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) try: callable_() + assertsql.asserter.statement_complete() finally: - testdata.set_assert_list(None, None) + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): - global testdata - testdata = TestData() - callable_() - self.assert_(testdata.sql_count == count, - "desired statement count %d does not match %d" % ( - count, testdata.sql_count)) - - def capture_sql(self, db, callable_): - global testdata - testdata = TestData() - buffer = StringIO() - testdata.buffer = buffer - try: - callable_() - return buffer.getvalue() - finally: - testdata.buffer = None + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + _otest_metadata = None class ORMTest(TestBase, AssertsExecutionResults): |