diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/test/testing.py | |
parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/test/testing.py')
-rw-r--r-- | lib/sqlalchemy/test/testing.py | 133 |
1 files changed, 88 insertions, 45 deletions
diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 36c7d340a..16a13d9d3 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -8,10 +8,12 @@ import types import warnings from cStringIO import StringIO -from sqlalchemy.test import config, assertsql +from sqlalchemy.test import config, assertsql, util as testutil from sqlalchemy.util import function_named +from engines import drop_all_tables -from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool +from nose import SkipTest _ops = { '<': operator.lt, '>': operator.gt, @@ -80,6 +82,19 @@ def future(fn): "Unexpected success for future test '%s'" % fn_name) return function_named(decorated, fn_name) +def db_spec(*dbs): + dialects = set([x for x in dbs if '+' not in x]) + drivers = set([x[1:] for x in dbs if x.startswith('+')]) + specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers]) + + def check(engine): + return engine.name in dialects or \ + engine.driver in drivers or \ + (engine.name, engine.driver) in specs + + return check + + def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database implementation. @@ -90,23 +105,25 @@ def fails_on(dbs, reason): succeeds, a failure is reported. """ + spec = db_spec(dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name != dbs: + if not spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, reason)) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason)) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs): databases except those listed. """ + spec = db_spec(*dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: + if spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, str(ex))) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, str(ex))) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -145,12 +164,13 @@ def crashes(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -169,12 +189,13 @@ def _block_unconditionally(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -198,6 +219,7 @@ def exclude(db, op, spec, reason): """ carp = _should_carp_about_exclusion(reason) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): @@ -242,7 +264,9 @@ def _is_excluded(db, op, spec): _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ - if config.db.name != db: + vendor_spec = db_spec(db) + + if not vendor_spec(config.db): return False version = _server_version() @@ -255,7 +279,12 @@ def _server_version(bind=None): if bind is None: bind = config.db - return bind.dialect.server_version_info(bind.contextual_connect()) + + # force metadata to be retrieved + conn = bind.connect() + version = getattr(bind.dialect, 'server_version_info', ()) + conn.close() + return version def skip_if(predicate, reason=None): """Skip a test if predicate is true.""" @@ -266,8 +295,7 @@ def skip_if(predicate, reason=None): if predicate(): msg = "'%s' skipped on DB %s version '%s': %s" % ( fn_name, config.db.name, _server_version(), reason) - print msg - return True + raise SkipTest(msg) else: return fn(*args, **kw) return function_named(maybe, fn_name) @@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings): strings; these will be matched to the root of the warning description by warnings.filterwarnings(). """ + spec = db_spec(db) + def decorate(fn): def maybe(*args, **kw): if isinstance(db, basestring): - if config.db.name != db: + if not spec(config.db): return fn(*args, **kw) else: wrapped = emits_warning(*warnings)(fn) @@ -384,6 +414,19 @@ def resetwarnings(): if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning) +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs + + def against(*queries): """Boolean predicate, compares to testing database configuration. @@ -394,21 +437,20 @@ def against(*queries): Also supports comparison to database version when provided with one or more 3-tuples of dialect name, operator, and version specification:: - testing.against('mysql', 'postgres') + testing.against('mysql', 'postgresql') testing.against(('mysql', '>=', (5, 0, 0)) """ for query in queries: if isinstance(query, basestring): - if config.db.name == query: + if db_spec(query)(config.db): return True else: name, op, spec = query - if config.db.name != name: + if not db_spec(name)(config.db): continue - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) + have = _server_version() oper = hasattr(op, '__call__') and op or _ops[op] if oper(have, spec): @@ -545,16 +587,15 @@ class AssertsCompiledSQL(object): if dialect is None: dialect = getattr(self, '__dialect__', None) - if params is None: - keys = None - else: - keys = params.keys() + kw = {} + if params is not None: + kw['column_keys'] = params.keys() - c = clause.compile(column_keys=keys, dialect=dialect) + c = clause.compile(dialect=dialect, **kw) - print "\nSQL String:\n" + str(c) + repr(c.params) + print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {})) - cc = re.sub(r'\n', '', str(c)) + cc = re.sub(r'[\n\t]', '', str(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -563,18 +604,13 @@ class AssertsCompiledSQL(object): class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): - base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] eq_(c.primary_key, reflected_c.primary_key) eq_(c.nullable, reflected_c.nullable) - assert len( - set(type(reflected_c.type).__mro__).difference(base_mro).intersection( - set(type(c.type).__mro__).difference(base_mro) - ) - ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) @@ -586,14 +622,21 @@ class ComparesTables(object): elif against(('mysql', '<', (5, 0))): # ignore reflection of bogus db-generated DefaultClause() pass - elif not c.primary_key or not against('postgres'): - print repr(c) + elif not c.primary_key or not against('postgresql', 'mssql'): + #print repr(c) assert reflected_c.default is None, reflected_c.default assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] - + + def assert_types_base(self, c1, c2): + base_mro = sqltypes.TypeEngine.__mro__ + assert len( + set(type(c1.type).__mro__).difference(base_mro).intersection( + set(type(c2.type).__mro__).difference(base_mro) + ) + ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type) class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): @@ -678,7 +721,7 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_sql(self, db, callable_, list_, with_sequences=None): - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'): rules = with_sequences else: rules = list_ |