summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/test/testing.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
commit8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch)
treeae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/test/testing.py
parent7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff)
downloadsqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/test/testing.py')
-rw-r--r--lib/sqlalchemy/test/testing.py133
1 files changed, 88 insertions, 45 deletions
diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py
index 36c7d340a..16a13d9d3 100644
--- a/lib/sqlalchemy/test/testing.py
+++ b/lib/sqlalchemy/test/testing.py
@@ -8,10 +8,12 @@ import types
import warnings
from cStringIO import StringIO
-from sqlalchemy.test import config, assertsql
+from sqlalchemy.test import config, assertsql, util as testutil
from sqlalchemy.util import function_named
+from engines import drop_all_tables
-from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool
+from nose import SkipTest
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -80,6 +82,19 @@ def future(fn):
"Unexpected success for future test '%s'" % fn_name)
return function_named(decorated, fn_name)
+def db_spec(*dbs):
+ dialects = set([x for x in dbs if '+' not in x])
+ drivers = set([x[1:] for x in dbs if x.startswith('+')])
+ specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
+
+ def check(engine):
+ return engine.name in dialects or \
+ engine.driver in drivers or \
+ (engine.name, engine.driver) in specs
+
+ return check
+
+
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
implementation.
@@ -90,23 +105,25 @@ def fails_on(dbs, reason):
succeeds, a failure is reported.
"""
+ spec = db_spec(dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name != dbs:
+ if not spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, reason))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
@@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs):
databases except those listed.
"""
+ spec = db_spec(*dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name in dbs:
+ if spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, str(ex)))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, str(ex)))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
@@ -145,12 +164,13 @@ def crashes(db, reason):
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
@@ -169,12 +189,13 @@ def _block_unconditionally(db, reason):
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
@@ -198,6 +219,7 @@ def exclude(db, op, spec, reason):
"""
carp = _should_carp_about_exclusion(reason)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
@@ -242,7 +264,9 @@ def _is_excluded(db, op, spec):
_is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
"""
- if config.db.name != db:
+ vendor_spec = db_spec(db)
+
+ if not vendor_spec(config.db):
return False
version = _server_version()
@@ -255,7 +279,12 @@ def _server_version(bind=None):
if bind is None:
bind = config.db
- return bind.dialect.server_version_info(bind.contextual_connect())
+
+ # force metadata to be retrieved
+ conn = bind.connect()
+ version = getattr(bind.dialect, 'server_version_info', ())
+ conn.close()
+ return version
def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
@@ -266,8 +295,7 @@ def skip_if(predicate, reason=None):
if predicate():
msg = "'%s' skipped on DB %s version '%s': %s" % (
fn_name, config.db.name, _server_version(), reason)
- print msg
- return True
+ raise SkipTest(msg)
else:
return fn(*args, **kw)
return function_named(maybe, fn_name)
@@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings):
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
+ spec = db_spec(db)
+
def decorate(fn):
def maybe(*args, **kw):
if isinstance(db, basestring):
- if config.db.name != db:
+ if not spec(config.db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
@@ -384,6 +414,19 @@ def resetwarnings():
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs
+
+
def against(*queries):
"""Boolean predicate, compares to testing database configuration.
@@ -394,21 +437,20 @@ def against(*queries):
Also supports comparison to database version when provided with one or
more 3-tuples of dialect name, operator, and version specification::
- testing.against('mysql', 'postgres')
+ testing.against('mysql', 'postgresql')
testing.against(('mysql', '>=', (5, 0, 0))
"""
for query in queries:
if isinstance(query, basestring):
- if config.db.name == query:
+ if db_spec(query)(config.db):
return True
else:
name, op, spec = query
- if config.db.name != name:
+ if not db_spec(name)(config.db):
continue
- have = config.db.dialect.server_version_info(
- config.db.contextual_connect())
+ have = _server_version()
oper = hasattr(op, '__call__') and op or _ops[op]
if oper(have, spec):
@@ -545,16 +587,15 @@ class AssertsCompiledSQL(object):
if dialect is None:
dialect = getattr(self, '__dialect__', None)
- if params is None:
- keys = None
- else:
- keys = params.keys()
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
- c = clause.compile(column_keys=keys, dialect=dialect)
+ c = clause.compile(dialect=dialect, **kw)
- print "\nSQL String:\n" + str(c) + repr(c.params)
+ print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
- cc = re.sub(r'\n', '', str(c))
+ cc = re.sub(r'[\n\t]', '', str(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -563,18 +604,13 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table):
- base_mro = sqltypes.TypeEngine.__mro__
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
- assert len(
- set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
- set(type(c.type).__mro__).difference(base_mro)
- )
- ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
@@ -586,14 +622,21 @@ class ComparesTables(object):
elif against(('mysql', '<', (5, 0))):
# ignore reflection of bogus db-generated DefaultClause()
pass
- elif not c.primary_key or not against('postgres'):
- print repr(c)
+ elif not c.primary_key or not against('postgresql', 'mssql'):
+ #print repr(c)
assert reflected_c.default is None, reflected_c.default
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
+ def assert_types_base(self, c1, c2):
+ base_mro = sqltypes.TypeEngine.__mro__
+ assert len(
+ set(type(c1.type).__mro__).difference(base_mro).intersection(
+ set(type(c2.type).__mro__).difference(base_mro)
+ )
+ ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
@@ -678,7 +721,7 @@ class AssertsExecutionResults(object):
assertsql.asserter.clear_rules()
def assert_sql(self, db, callable_, list_, with_sequences=None):
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
rules = with_sequences
else:
rules = list_