summaryrefslogtreecommitdiff
path: root/test/lib/testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/testing.py')
-rw-r--r--test/lib/testing.py70
1 files changed, 35 insertions, 35 deletions
diff --git a/test/lib/testing.py b/test/lib/testing.py
index 6a2c62959..d6338cf10 100644
--- a/test/lib/testing.py
+++ b/test/lib/testing.py
@@ -18,7 +18,7 @@ from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, \
from sqlalchemy.engine import default
from nose import SkipTest
-
+
_ops = { '<': operator.lt,
'>': operator.gt,
'==': operator.eq,
@@ -90,9 +90,9 @@ def db_spec(*dbs):
return engine.name in dialects or \
engine.driver in drivers or \
(engine.name, engine.driver) in specs
-
+
return check
-
+
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
@@ -105,7 +105,7 @@ def fails_on(dbs, reason):
"""
spec = db_spec(dbs)
-
+
@decorator
def decorate(fn, *args, **kw):
if not spec(config.db):
@@ -132,7 +132,7 @@ def fails_on_everything_except(*dbs):
"""
spec = db_spec(*dbs)
-
+
@decorator
def decorate(fn, *args, **kw):
if spec(config.db):
@@ -211,7 +211,7 @@ def only_on(dbs, reason):
print >> sys.stderr, msg
return True
return decorate
-
+
def exclude(db, op, spec, reason):
"""Mark a test as unsupported by specific database server versions.
@@ -225,7 +225,7 @@ def exclude(db, op, spec, reason):
"""
carp = _should_carp_about_exclusion(reason)
-
+
@decorator
def decorate(fn, *args, **kw):
if _is_excluded(db, op, spec):
@@ -283,7 +283,7 @@ def _server_version(bind=None):
if bind is None:
bind = config.db
-
+
# force metadata to be retrieved
conn = bind.connect()
version = getattr(bind.dialect, 'server_version_info', ())
@@ -294,7 +294,7 @@ def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
reason = reason or predicate.__name__
carp = _should_carp_about_exclusion(reason)
-
+
@decorator
def decorate(fn, *args, **kw):
if predicate():
@@ -320,7 +320,7 @@ def emits_warning(*messages):
# and may work on non-CPython if they keep to the spirit of
# warnings.showwarning's docstring.
# - update: jython looks ok, it uses cpython's module
-
+
@decorator
def decorate(fn, *args, **kw):
# todo: should probably be strict about this, too
@@ -350,7 +350,7 @@ def emits_warning_on(db, *warnings):
warnings.filterwarnings().
"""
spec = db_spec(db)
-
+
@decorator
def decorate(fn, *args, **kw):
if isinstance(db, basestring):
@@ -369,16 +369,16 @@ def emits_warning_on(db, *warnings):
def assert_warnings(fn, warnings):
"""Assert that each of the given warnings are emitted by fn."""
-
+
orig_warn = util.warn
def capture_warnings(*args, **kw):
orig_warn(*args, **kw)
popwarn = warnings.pop(0)
eq_(args[0], popwarn)
util.warn = util.langhelpers.warn = capture_warnings
-
+
return emits_warning()(fn)()
-
+
def uses_deprecated(*messages):
"""Mark a test as immune from fatal deprecation warnings.
@@ -419,7 +419,7 @@ def uses_deprecated(*messages):
def testing_warn(msg, stacklevel=3):
"""Replaces sqlalchemy.util.warn during tests."""
-
+
filename = "test.lib.testing"
lineno = 1
if isinstance(msg, basestring):
@@ -429,9 +429,9 @@ def testing_warn(msg, stacklevel=3):
def resetwarnings():
"""Reset warning behavior to testing defaults."""
-
+
util.warn = util.langhelpers.warn = testing_warn
-
+
warnings.filterwarnings('ignore',
category=sa_exc.SAPendingDeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
@@ -440,17 +440,17 @@ def resetwarnings():
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
-
+
Hardcoded at the moment, a modular system can be built here
to support things like PG prepared transactions, tables all
dropped, etc.
-
+
"""
testutil.lazy_gc()
assert not pool._refs
-
-
+
+
def against(*queries):
"""Boolean predicate, compares to testing database configuration.
@@ -523,7 +523,7 @@ def assert_raises(except_cls, callable_, *args, **kw):
success = False
except except_cls, e:
success = True
-
+
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
@@ -537,7 +537,7 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
def fail(msg):
assert False, msg
-
+
def fixture(table, columns, *rows):
"""Insert data into table after creation."""
def onload(event, schema_item, connection):
@@ -624,34 +624,34 @@ class TestBase(object):
def assert_(self, val, msg=None):
assert val, msg
-
+
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
if use_default_dialect:
dialect = default.DefaultDialect()
-
+
if dialect is None:
dialect = getattr(self, '__dialect__', None)
kw = {}
if params is not None:
kw['column_keys'] = params.keys()
-
+
if isinstance(clause, orm.Query):
context = clause._compile_context()
context.statement.use_labels = True
clause = context.statement
-
+
c = clause.compile(dialect=dialect, **kw)
param_str = repr(getattr(c, 'params', {}))
# Py3K
#param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
-
+
print "\nSQL String:\n" + str(c) + param_str
-
+
cc = re.sub(r'[\n\t]', '', str(c))
-
+
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
@@ -665,7 +665,7 @@ class ComparesTables(object):
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
-
+
if strict_types:
assert type(reflected_c.type) is type(c.type), \
"Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
@@ -683,7 +683,7 @@ class ComparesTables(object):
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
-
+
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(c2.type),\
"On column %r, type '%s' doesn't correspond to type '%s'" % \
@@ -770,13 +770,13 @@ class AssertsExecutionResults(object):
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
-
+
def assert_sql(self, db, callable_, list_, with_sequences=None):
if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
rules = with_sequences
else:
rules = list_
-
+
newrules = []
for rule in rules:
if isinstance(rule, dict):
@@ -786,7 +786,7 @@ class AssertsExecutionResults(object):
else:
newrule = assertsql.ExactSQL(*rule)
newrules.append(newrule)
-
+
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):