summaryrefslogtreecommitdiff
path: root/test/testlib/testing.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-12-10 02:16:52 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-12-10 02:16:52 +0000
commita2f90fd0038ab6586022762467ed95a84e62e224 (patch)
tree565cb80d9c3c5885f992c712890f26991c6cebc8 /test/testlib/testing.py
parent41f5f3465a444f36473593e179168841829fc1a5 (diff)
downloadsqlalchemy-a2f90fd0038ab6586022762467ed95a84e62e224.tar.gz
- reworked the "SQL assertion" code to something more flexible and based off of ConnectionProxy. upcoming changes to dependency.py
will make use of the enhanced flexibility.
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r--test/testlib/testing.py154
1 files changed, 25 insertions, 129 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index f6a16a4f8..5f5d323c7 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -13,6 +13,7 @@ from cStringIO import StringIO
import testlib.config as config
from testlib.compat import _function_named
+from testlib import assertsql
# Delayed imports
MetaData = None
@@ -459,110 +460,6 @@ def fixture(table, columns, *rows):
for column_values in rows])
table.append_ddl_listener('after-create', onload)
-class TestData(object):
- """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
-
- def __init__(self):
- self.set_assert_list(None, None)
- self.sql_count = 0
- self.buffer = None
-
- def set_assert_list(self, unittest, list):
- self.unittest = unittest
- self.assert_list = list
- if list is not None:
- self.assert_list.reverse()
-
-testdata = TestData()
-
-
-class ExecutionContextWrapper(object):
- """instruments the ExecutionContext created by the Engine so that SQL expressions
- can be tracked."""
-
- def __init__(self, ctx):
- self.__dict__['ctx'] = ctx
- def __getattr__(self, key):
- return getattr(self.ctx, key)
- def __setattr__(self, key, value):
- setattr(self.ctx, key, value)
-
- trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE)
- def post_exec(self):
- ctx = self.ctx
- statement = unicode(ctx.compiled)
- statement = re.sub(r'\n', '', ctx.statement)
- if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
- statement = statement[:-25]
- if testdata.buffer is not None:
- testdata.buffer.write(statement + "\n")
-
- if testdata.assert_list is not None:
- assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
- item = testdata.assert_list[-1]
- if not isinstance(item, dict):
- item = testdata.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if '_converted' not in item:
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- testdata.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
-
- (query, params) = item
- if callable(params):
- params = params(ctx)
- if params is not None and not isinstance(params, list):
- params = [params]
-
- parameters = ctx.compiled_parameters
-
- query = self.convert_statement(query)
- equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
- ) \
- and \
- ( (params is None) or (params == parameters)
- or params == [dict([(k.strip('_'), v)
- for (k, v) in p.items()])
- for p in parameters]
- )
- testdata.unittest.assert_(equivalent,
- "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- testdata.sql_count += 1
- self.ctx.post_exec()
-
- def convert_statement(self, query):
- paramstyle = self.ctx.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
- return query
-
-
def _import_by_name(name):
submodule = name.split('.')[-1]
return __import__(name, globals(), locals(), [submodule])
@@ -827,36 +724,35 @@ class AssertsExecutionResults(object):
cls.__name__, repr(expected_item)))
return True
- def assert_sql(self, db, callable_, list, with_sequences=None):
- global testdata
- testdata = TestData()
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
- testdata.set_assert_list(self, with_sequences)
- else:
- testdata.set_assert_list(self, list)
+ def assert_sql_execution(self, db, callable_, *rules):
+ assertsql.asserter.add_rules(rules)
try:
callable_()
+ assertsql.asserter.statement_complete()
finally:
- testdata.set_assert_list(None, None)
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
- global testdata
- testdata = TestData()
- callable_()
- self.assert_(testdata.sql_count == count,
- "desired statement count %d does not match %d" % (
- count, testdata.sql_count))
-
- def capture_sql(self, db, callable_):
- global testdata
- testdata = TestData()
- buffer = StringIO()
- testdata.buffer = buffer
- try:
- callable_()
- return buffer.getvalue()
- finally:
- testdata.buffer = None
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
_otest_metadata = None
class ORMTest(TestBase, AssertsExecutionResults):