diff options
Diffstat (limited to 'test/lib/testing.py')
-rw-r--r-- | test/lib/testing.py | 70 |
1 files changed, 35 insertions, 35 deletions
diff --git a/test/lib/testing.py b/test/lib/testing.py index 6a2c62959..d6338cf10 100644 --- a/test/lib/testing.py +++ b/test/lib/testing.py @@ -18,7 +18,7 @@ from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, \ from sqlalchemy.engine import default from nose import SkipTest - + _ops = { '<': operator.lt, '>': operator.gt, '==': operator.eq, @@ -90,9 +90,9 @@ def db_spec(*dbs): return engine.name in dialects or \ engine.driver in drivers or \ (engine.name, engine.driver) in specs - + return check - + def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database @@ -105,7 +105,7 @@ def fails_on(dbs, reason): """ spec = db_spec(dbs) - + @decorator def decorate(fn, *args, **kw): if not spec(config.db): @@ -132,7 +132,7 @@ def fails_on_everything_except(*dbs): """ spec = db_spec(*dbs) - + @decorator def decorate(fn, *args, **kw): if spec(config.db): @@ -211,7 +211,7 @@ def only_on(dbs, reason): print >> sys.stderr, msg return True return decorate - + def exclude(db, op, spec, reason): """Mark a test as unsupported by specific database server versions. @@ -225,7 +225,7 @@ def exclude(db, op, spec, reason): """ carp = _should_carp_about_exclusion(reason) - + @decorator def decorate(fn, *args, **kw): if _is_excluded(db, op, spec): @@ -283,7 +283,7 @@ def _server_version(bind=None): if bind is None: bind = config.db - + # force metadata to be retrieved conn = bind.connect() version = getattr(bind.dialect, 'server_version_info', ()) @@ -294,7 +294,7 @@ def skip_if(predicate, reason=None): """Skip a test if predicate is true.""" reason = reason or predicate.__name__ carp = _should_carp_about_exclusion(reason) - + @decorator def decorate(fn, *args, **kw): if predicate(): @@ -320,7 +320,7 @@ def emits_warning(*messages): # and may work on non-CPython if they keep to the spirit of # warnings.showwarning's docstring. # - update: jython looks ok, it uses cpython's module - + @decorator def decorate(fn, *args, **kw): # todo: should probably be strict about this, too @@ -350,7 +350,7 @@ def emits_warning_on(db, *warnings): warnings.filterwarnings(). """ spec = db_spec(db) - + @decorator def decorate(fn, *args, **kw): if isinstance(db, basestring): @@ -369,16 +369,16 @@ def emits_warning_on(db, *warnings): def assert_warnings(fn, warnings): """Assert that each of the given warnings are emitted by fn.""" - + orig_warn = util.warn def capture_warnings(*args, **kw): orig_warn(*args, **kw) popwarn = warnings.pop(0) eq_(args[0], popwarn) util.warn = util.langhelpers.warn = capture_warnings - + return emits_warning()(fn)() - + def uses_deprecated(*messages): """Mark a test as immune from fatal deprecation warnings. @@ -419,7 +419,7 @@ def uses_deprecated(*messages): def testing_warn(msg, stacklevel=3): """Replaces sqlalchemy.util.warn during tests.""" - + filename = "test.lib.testing" lineno = 1 if isinstance(msg, basestring): @@ -429,9 +429,9 @@ def testing_warn(msg, stacklevel=3): def resetwarnings(): """Reset warning behavior to testing defaults.""" - + util.warn = util.langhelpers.warn = testing_warn - + warnings.filterwarnings('ignore', category=sa_exc.SAPendingDeprecationWarning) warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) @@ -440,17 +440,17 @@ def resetwarnings(): def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. - + Hardcoded at the moment, a modular system can be built here to support things like PG prepared transactions, tables all dropped, etc. - + """ testutil.lazy_gc() assert not pool._refs - - + + def against(*queries): """Boolean predicate, compares to testing database configuration. @@ -523,7 +523,7 @@ def assert_raises(except_cls, callable_, *args, **kw): success = False except except_cls, e: success = True - + # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" @@ -537,7 +537,7 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): def fail(msg): assert False, msg - + def fixture(table, columns, *rows): """Insert data into table after creation.""" def onload(event, schema_item, connection): @@ -624,34 +624,34 @@ class TestBase(object): def assert_(self, val, msg=None): assert val, msg - + class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False): if use_default_dialect: dialect = default.DefaultDialect() - + if dialect is None: dialect = getattr(self, '__dialect__', None) kw = {} if params is not None: kw['column_keys'] = params.keys() - + if isinstance(clause, orm.Query): context = clause._compile_context() context.statement.use_labels = True clause = context.statement - + c = clause.compile(dialect=dialect, **kw) param_str = repr(getattr(c, 'params', {})) # Py3K #param_str = param_str.encode('utf-8').decode('ascii', 'ignore') - + print "\nSQL String:\n" + str(c) + param_str - + cc = re.sub(r'[\n\t]', '', str(c)) - + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) if checkparams is not None: @@ -665,7 +665,7 @@ class ComparesTables(object): assert reflected_c is reflected_table.c[c.name] eq_(c.primary_key, reflected_c.primary_key) eq_(c.nullable, reflected_c.nullable) - + if strict_types: assert type(reflected_c.type) is type(c.type), \ "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) @@ -683,7 +683,7 @@ class ComparesTables(object): assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None - + def assert_types_base(self, c1, c2): assert c1.type._compare_type_affinity(c2.type),\ "On column %r, type '%s' doesn't correspond to type '%s'" % \ @@ -770,13 +770,13 @@ class AssertsExecutionResults(object): assertsql.asserter.statement_complete() finally: assertsql.asserter.clear_rules() - + def assert_sql(self, db, callable_, list_, with_sequences=None): if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'): rules = with_sequences else: rules = list_ - + newrules = [] for rule in rules: if isinstance(rule, dict): @@ -786,7 +786,7 @@ class AssertsExecutionResults(object): else: newrule = assertsql.ExactSQL(*rule) newrules.append(newrule) - + self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): |