diff options
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r-- | test/testlib/testing.py | 150 |
1 files changed, 108 insertions, 42 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 95071a475..d5fb3b4e5 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -21,18 +21,32 @@ _ops = { '<': operator.lt, 'between': lambda val, pair: val >= pair[0] and val <= pair[1], } -def unsupported(*dbs): - """Mark a test as unsupported by one or more database implementations""" +def fails_on(*dbs): + """Mark a test as expected to fail on one or more database implementations. + + Unlike ``unsupported``, tests marked as ``fails_on`` will be run + for the named databases. The test is expected to fail and the unit test + logic is inverted: if the test fails, a success is reported. If the test + succeeds, a failure is reported. + """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: - print "'%s' unsupported on DB implementation '%s'" % ( - fn_name, config.db.name) - return True - else: + if config.db.name not in dbs: return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s': %s" % ( + fn_name, config.db.name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s'" % + (fn_name, config.db.name)) try: maybe.__name__ = fn_name except: @@ -40,19 +54,17 @@ def unsupported(*dbs): return maybe return decorate -def fails_on(*dbs): - """Mark a test as expected to fail on one or more database implementations. +def fails_on_everything_except(*dbs): + """Mark a test as expected to fail on most database implementations. - Unlike ``unsupported``, tests marked as ``fails_on`` will be run - for the named databases. The test is expected to fail and the unit test - logic is inverted: if the test fails, a success is reported. If the test - succeeds, a failure is reported. + Like ``fails_on``, except failure is the expected outcome on all + databases except those listed. """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name not in dbs: + if config.db.name in dbs: return fn(*args, **kw) else: try: @@ -73,18 +85,22 @@ def fails_on(*dbs): return maybe return decorate -def supported(*dbs): - """Mark a test as supported by one or more database implementations""" +def unsupported(*dbs): + """Mark a test as unsupported by one or more database implementations. + + 'unsupported' tests will be skipped unconditionally. Useful for feature + tests that cause deadlocks or other fatal problems. + """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if config.db.name in dbs: - return fn(*args, **kw) - else: print "'%s' unsupported on DB implementation '%s'" % ( fn_name, config.db.name) return True + else: + return fn(*args, **kw) try: maybe.__name__ = fn_name except: @@ -95,7 +111,7 @@ def supported(*dbs): def exclude(db, op, spec): """Mark a test as unsupported by specific database server versions. - Stackable, both with other excludes and supported/unsupported. Examples:: + Stackable, both with other excludes and other decorators. Examples:: # Not supported by mydb versions less than 1, 0 @exclude('mydb', '<', (1,0)) # Other operators work too @@ -106,17 +122,9 @@ def exclude(db, op, spec): def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name != db: - return fn(*args, **kw) - - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) - - oper = hasattr(op, '__call__') and op or _ops[op] - - if oper(have, spec): + if _is_excluded(db, op, spec): print "'%s' unsupported on DB %s version '%s'" % ( - fn_name, config.db.name, have) + fn_name, config.db.name, _server_version()) return True else: return fn(*args, **kw) @@ -127,6 +135,41 @@ def exclude(db, op, spec): return maybe return decorate +def _is_excluded(db, op, spec): + """Return True if the configured db matches an exclusion specification. + + db: + A dialect name + op: + An operator or stringified operator, such as '==' + spec: + A value that will be compared to the dialect's server_version_info + using the supplied operator. + + Examples:: + # Not supported by mydb versions less than 1, 0 + _is_excluded('mydb', '<', (1,0)) + # Other operators work too + _is_excluded('bigdb', '==', (9,0,9)) + _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + """ + + if config.db.name != db: + return False + + version = _server_version() + + oper = hasattr(op, '__call__') and op or _ops[op] + return oper(version, spec) + +def _server_version(bind=None): + """Return a server_version_info tuple.""" + + if bind is None: + bind = config.db + return bind.dialect.server_version_info(bind.contextual_connect()) + + def against(*queries): """Boolean predicate, compares to testing database configuration. @@ -262,6 +305,12 @@ class ExecutionContextWrapper(object): return query class PersistTest(unittest.TestCase): + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None def __init__(self, *args, **params): unittest.TestCase.__init__(self, *args, **params) @@ -426,7 +475,8 @@ class ORMTest(AssertMixin): _otest_metadata = MetaData(config.db) else: _otest_metadata = self.metadata - _otest_metadata.bind = config.db + if self.metadata.bind is None: + _otest_metadata.bind = config.db self.define_tables(_otest_metadata) _otest_metadata.create_all() self.insert_data() @@ -490,23 +540,39 @@ class TTestSuite(unittest.TestSuite): return self(result) def __call__(self, result): - try: - if self._initTest is not None: - self._initTest.setUpAll() - except: - # skip tests if global setup fails - ex = self.__exc_info() - for test in self._tests: - result.addError(test, ex) - return False + init = getattr(self, '_initTest', None) + if init is not None: + if (hasattr(init, '__unsupported_on__') and + config.db.name in init.__unsupported_on__): + print "'%s' unsupported on DB implementation '%s'" % ( + init.__class__.__name__, config.db.name) + return True + if (getattr(init, '__only_on__', None) not in (None,config.db.name)): + print "'%s' unsupported on DB implementation '%s'" % ( + init.__class__.__name__, config.db.name) + return True + for rule in getattr(init, '__excluded_on__', ()): + if _is_excluded(*rule): + print "'%s' unsupported on DB %s version %s" % ( + init.__class__.__name__, config.db.name, + _server_version()) + return True + try: + init.setUpAll() + except: + # skip tests if global setup fails + ex = self.__exc_info() + for test in self._tests: + result.addError(test, ex) + return False try: return self.do_run(result) finally: try: - if self._initTest is not None: - self._initTest.tearDownAll() + if init is not None: + init.tearDownAll() except: - result.addError(self._initTest, self.__exc_info()) + result.addError(init, self.__exc_info()) pass def __exc_info(self): |