summaryrefslogtreecommitdiff
path: root/test/lib/assertsql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-11-15 19:37:50 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-11-15 19:37:50 -0500
commite1402efb198f96090833a9b561cdba8dee937f70 (patch)
treec20dc33466476bbf394e8e19da7a045c8a008d08 /test/lib/assertsql.py
parent756aa2724e495b8a969bca73d133b27615a343e7 (diff)
downloadsqlalchemy-e1402efb198f96090833a9b561cdba8dee937f70.tar.gz
- move sqlalchemy.test to test.lib
Diffstat (limited to 'test/lib/assertsql.py')
-rw-r--r--test/lib/assertsql.py295
1 files changed, 295 insertions, 0 deletions
diff --git a/test/lib/assertsql.py b/test/lib/assertsql.py
new file mode 100644
index 000000000..a044f9d02
--- /dev/null
+++ b/test/lib/assertsql.py
@@ -0,0 +1,295 @@
+
+from sqlalchemy.interfaces import ConnectionProxy
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+from sqlalchemy import util
+import re
+
+class AssertRule(object):
+ def process_execute(self, clauseelement, *multiparams, **params):
+ pass
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ """Return True if this rule has been consumed, False if not.
+
+ Should raise an AssertionError if this rule's condition has definitely failed.
+
+ """
+ raise NotImplementedError()
+
+ def rule_passed(self):
+ """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
+
+ raise NotImplementedError()
+
+ def consume_final(self):
+ """Return True if this rule has been consumed.
+
+ Should raise an AssertionError if this rule's condition has not been consumed or has failed.
+
+ """
+
+ if self._result is None:
+ assert False, "Rule has not been consumed"
+
+ return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+ def __init__(self):
+ self._result = None
+ self._errmsg = ""
+
+ def rule_passed(self):
+ return self._result
+
+ def is_consumed(self):
+ if self._result is None:
+ return False
+
+ assert self._result, self._errmsg
+
+ return True
+
+class ExactSQL(SQLMatchRule):
+ def __init__(self, sql, params=None):
+ SQLMatchRule.__init__(self)
+ self.sql = sql
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(context.unicode_statement, context)
+ _received_parameters = context.compiled_parameters
+
+ # TODO: remove this step once all unit tests
+ # are migrated, as ExactSQL should really be *exact* SQL
+ sql = _process_assertion_statement(self.sql, context)
+
+ equivalent = _received_statement == sql
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+ equivalent = equivalent and params == context.compiled_parameters
+ else:
+ params = {}
+
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for exact statement %r exact params %r, " \
+ "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
+
+
+class RegexSQL(SQLMatchRule):
+ def __init__(self, regex, params=None):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(context.unicode_statement, context)
+ _received_parameters = context.compiled_parameters
+
+ equivalent = bool(self.regex.match(_received_statement))
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for regex %r partial params %r, "\
+ "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+ def __init__(self, statement, params):
+ SQLMatchRule.__init__(self)
+ self.statement = statement
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_parameters = list(context.compiled_parameters)
+
+ # recompile from the context, using the default dialect
+ compiled = context.compiled.statement.\
+ compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
+
+ _received_statement = re.sub(r'\n', '', str(compiled))
+
+ equivalent = self.statement == _received_statement
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while params:
+ param = dict(params.pop(0))
+ for k, v in context.compiled.params.iteritems():
+ param.setdefault(k, v)
+
+ if param not in _received_parameters:
+ equivalent = False
+ break
+ else:
+ _received_parameters.remove(param)
+ if _received_parameters:
+ equivalent = False
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for compiled statement %r partial params %r, " \
+ "received %r with params %r" % \
+ (self.statement, all_params, _received_statement, all_received)
+ #print self._errmsg
+
+
+class CountStatements(AssertRule):
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ self._statement_count += 1
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ return False
+
+ def consume_final(self):
+ assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
+ return True
+
+class AllOf(AssertRule):
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ for rule in self.rules:
+ rule.process_execute(clauseelement, *multiparams, **params)
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ for rule in self.rules:
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ def is_consumed(self):
+ if not self.rules:
+ return True
+
+ for rule in list(self.rules):
+ if rule.rule_passed(): # a rule passed, move on
+ self.rules.remove(rule)
+ return len(self.rules) == 0
+
+ assert False, "No assertion rules were satisfied for statement"
+
+ def consume_final(self):
+ return len(self.rules) == 0
+
+def _process_engine_statement(query, context):
+ if util.jython:
+ # oracle+zxjdbc passes a PyStatement when returning into
+ query = unicode(query)
+ if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
+ query = query[:-25]
+
+ query = re.sub(r'\n', '', query)
+
+ return query
+
+def _process_assertion_statement(query, context):
+ paramstyle = context.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+
+ return query
+
+class SQLAssert(ConnectionProxy):
+ rules = None
+
+ def add_rules(self, rules):
+ self.rules = list(rules)
+
+ def statement_complete(self):
+ for rule in self.rules:
+ if not rule.consume_final():
+ assert False, "All statements are complete, but pending assertion rules remain"
+
+ def clear_rules(self):
+ del self.rules
+
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ result = execute(clauseelement, *multiparams, **params)
+
+ if self.rules is not None:
+ if not self.rules:
+ assert False, "All rules have been exhausted, but further statements remain"
+ rule = self.rules[0]
+ rule.process_execute(clauseelement, *multiparams, **params)
+ if rule.is_consumed():
+ self.rules.pop(0)
+
+ return result
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ result = execute(cursor, statement, parameters, context)
+
+ if self.rules:
+ rule = self.rules[0]
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ return result
+
+asserter = SQLAssert()
+