summaryrefslogtreecommitdiff
path: root/test/lib/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/assertsql.py')
-rw-r--r--test/lib/assertsql.py316
1 files changed, 0 insertions, 316 deletions
diff --git a/test/lib/assertsql.py b/test/lib/assertsql.py
deleted file mode 100644
index 897f4b3b1..000000000
--- a/test/lib/assertsql.py
+++ /dev/null
@@ -1,316 +0,0 @@
-
-from sqlalchemy.interfaces import ConnectionProxy
-from sqlalchemy.engine.default import DefaultDialect
-from sqlalchemy.engine.base import Connection
-from sqlalchemy import util
-import re
-
-class AssertRule(object):
-
- def process_execute(self, clauseelement, *multiparams, **params):
- pass
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- pass
-
- def is_consumed(self):
- """Return True if this rule has been consumed, False if not.
-
- Should raise an AssertionError if this rule's condition has
- definitely failed.
-
- """
-
- raise NotImplementedError()
-
- def rule_passed(self):
- """Return True if the last test of this rule passed, False if
- failed, None if no test was applied."""
-
- raise NotImplementedError()
-
- def consume_final(self):
- """Return True if this rule has been consumed.
-
- Should raise an AssertionError if this rule's condition has not
- been consumed or has failed.
-
- """
-
- if self._result is None:
- assert False, 'Rule has not been consumed'
- return self.is_consumed()
-
-class SQLMatchRule(AssertRule):
- def __init__(self):
- self._result = None
- self._errmsg = ""
-
- def rule_passed(self):
- return self._result
-
- def is_consumed(self):
- if self._result is None:
- return False
-
- assert self._result, self._errmsg
-
- return True
-
-class ExactSQL(SQLMatchRule):
-
- def __init__(self, sql, params=None):
- SQLMatchRule.__init__(self)
- self.sql = sql
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
-
- # TODO: remove this step once all unit tests are migrated, as
- # ExactSQL should really be *exact* SQL
-
- sql = _process_assertion_statement(self.sql, context)
- equivalent = _received_statement == sql
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
- equivalent = equivalent and params \
- == context.compiled_parameters
- else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = \
- 'Testing for exact statement %r exact params %r, '\
- 'received %r with params %r' % (sql, params,
- _received_statement, _received_parameters)
-
-
-class RegexSQL(SQLMatchRule):
-
- def __init__(self, regex, params=None):
- SQLMatchRule.__init__(self)
- self.regex = re.compile(regex)
- self.orig_regex = regex
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
- equivalent = bool(self.regex.match(_received_statement))
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
-
- # do a positive compare only
-
- for param, received in zip(params, _received_parameters):
- for k, v in param.iteritems():
- if k not in received or received[k] != v:
- equivalent = False
- break
- else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = \
- 'Testing for regex %r partial params %r, received %r '\
- 'with params %r' % (self.orig_regex, params,
- _received_statement,
- _received_parameters)
-
-class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params):
- SQLMatchRule.__init__(self)
- self.statement = statement
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_parameters = list(context.compiled_parameters)
-
- # recompile from the context, using the default dialect
-
- compiled = \
- context.compiled.statement.compile(dialect=DefaultDialect(),
- column_keys=context.compiled.column_keys)
- _received_statement = re.sub(r'\n', '', str(compiled))
- equivalent = self.statement == _received_statement
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
- all_params = list(params)
- all_received = list(_received_parameters)
- while params:
- param = dict(params.pop(0))
- for k, v in context.compiled.params.iteritems():
- param.setdefault(k, v)
- if param not in _received_parameters:
- equivalent = False
- break
- else:
- _received_parameters.remove(param)
- if _received_parameters:
- equivalent = False
- else:
- params = {}
- all_params = {}
- all_received = []
- self._result = equivalent
- if not self._result:
- print 'Testing for compiled statement %r partial params '\
- '%r, received %r with params %r' % (self.statement,
- all_params, _received_statement, all_received)
- self._errmsg = \
- 'Testing for compiled statement %r partial params %r, '\
- 'received %r with params %r' % (self.statement,
- all_params, _received_statement, all_received)
-
-
- # print self._errmsg
-
-class CountStatements(AssertRule):
-
- def __init__(self, count):
- self.count = count
- self._statement_count = 0
-
- def process_execute(self, clauseelement, *multiparams, **params):
- self._statement_count += 1
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- pass
-
- def is_consumed(self):
- return False
-
- def consume_final(self):
- assert self.count == self._statement_count, \
- 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
- return True
-
-class AllOf(AssertRule):
-
- def __init__(self, *rules):
- self.rules = set(rules)
-
- def process_execute(self, clauseelement, *multiparams, **params):
- for rule in self.rules:
- rule.process_execute(clauseelement, *multiparams, **params)
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- for rule in self.rules:
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
-
- def is_consumed(self):
- if not self.rules:
- return True
- for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed, move on
- self.rules.remove(rule)
- return len(self.rules) == 0
- assert False, 'No assertion rules were satisfied for statement'
-
- def consume_final(self):
- return len(self.rules) == 0
-
-def _process_engine_statement(query, context):
- if util.jython:
-
- # oracle+zxjdbc passes a PyStatement when returning into
-
- query = unicode(query)
- if context.engine.name == 'mssql' \
- and query.endswith('; select scope_identity()'):
- query = query[:-25]
- query = re.sub(r'\n', '', query)
- return query
-
-def _process_assertion_statement(query, context):
- paramstyle = context.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
-
- return query
-
-class SQLAssert(object):
-
- rules = None
-
- def add_rules(self, rules):
- self.rules = list(rules)
-
- def statement_complete(self):
- for rule in self.rules:
- if not rule.consume_final():
- assert False, \
- 'All statements are complete, but pending '\
- 'assertion rules remain'
-
- def clear_rules(self):
- del self.rules
-
- def execute(self, conn, clauseelement, multiparams, params, result):
- if self.rules is not None:
- if not self.rules:
- assert False, \
- 'All rules have been exhausted, but further '\
- 'statements remain'
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
-
- def cursor_execute(self, conn, cursor, statement, parameters,
- context, executemany):
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
-
-asserter = SQLAssert()
-