summaryrefslogtreecommitdiff
path: root/test/lib
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib')
-rw-r--r--test/lib/assertsql.py30
-rw-r--r--test/lib/engines.py34
-rw-r--r--test/lib/entities.py2
-rw-r--r--test/lib/pickleable.py4
-rw-r--r--test/lib/profiling.py14
-rw-r--r--test/lib/requires.py32
-rw-r--r--test/lib/schema.py2
-rw-r--r--test/lib/testing.py70
-rw-r--r--test/lib/util.py30
9 files changed, 109 insertions, 109 deletions
diff --git a/test/lib/assertsql.py b/test/lib/assertsql.py
index b206f91fc..4954a0bfe 100644
--- a/test/lib/assertsql.py
+++ b/test/lib/assertsql.py
@@ -16,10 +16,10 @@ class AssertRule(object):
def is_consumed(self):
"""Return True if this rule has been consumed, False if not.
-
+
Should raise an AssertionError if this rule's condition has
definitely failed.
-
+
"""
raise NotImplementedError()
@@ -32,10 +32,10 @@ class AssertRule(object):
def consume_final(self):
"""Return True if this rule has been consumed.
-
+
Should raise an AssertionError if this rule's condition has not
been consumed or has failed.
-
+
"""
if self._result is None:
@@ -46,18 +46,18 @@ class SQLMatchRule(AssertRule):
def __init__(self):
self._result = None
self._errmsg = ""
-
+
def rule_passed(self):
return self._result
-
+
def is_consumed(self):
if self._result is None:
return False
-
+
assert self._result, self._errmsg
-
+
return True
-
+
class ExactSQL(SQLMatchRule):
def __init__(self, sql, params=None):
@@ -96,7 +96,7 @@ class ExactSQL(SQLMatchRule):
'Testing for exact statement %r exact params %r, '\
'received %r with params %r' % (sql, params,
_received_statement, _received_parameters)
-
+
class RegexSQL(SQLMatchRule):
@@ -194,7 +194,7 @@ class CompiledSQL(SQLMatchRule):
# print self._errmsg
-
+
class CountStatements(AssertRule):
def __init__(self, count):
@@ -216,7 +216,7 @@ class CountStatements(AssertRule):
'desired statement count %d does not match %d' \
% (self.count, self._statement_count)
return True
-
+
class AllOf(AssertRule):
def __init__(self, *rules):
@@ -243,7 +243,7 @@ class AllOf(AssertRule):
def consume_final(self):
return len(self.rules) == 0
-
+
def _process_engine_statement(query, context):
if util.jython:
@@ -255,7 +255,7 @@ def _process_engine_statement(query, context):
query = query[:-25]
query = re.sub(r'\n', '', query)
return query
-
+
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
@@ -311,4 +311,4 @@ class SQLAssert(object):
executemany)
asserter = SQLAssert()
-
+
diff --git a/test/lib/engines.py b/test/lib/engines.py
index 60272839b..1c97e23cf 100644
--- a/test/lib/engines.py
+++ b/test/lib/engines.py
@@ -65,11 +65,11 @@ def rollback_open_connections(fn, *args, **kw):
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
-
+
testing_reaper.close_all()
fn(*args, **kw)
-
-
+
+
@decorator
def close_open_connections(fn, *args, **kw):
"""Decorator that closes all connections after fn execution."""
@@ -88,7 +88,7 @@ def all_dialects(exclude=None):
if not mod:
mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
-
+
class ReconnectFixture(object):
def __init__(self, dbapi):
self.dbapi = dbapi
@@ -135,11 +135,11 @@ def testing_engine(url=None, options=None):
event.listen(engine, 'after_execute', asserter.execute)
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
-
+
# may want to call this, results
# in first-connect initializers
#engine.connect()
-
+
return engine
def utf8_engine(url=None, options=None):
@@ -165,18 +165,18 @@ def utf8_engine(url=None, options=None):
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
-
+
This is normally used to test DDL generation flow as emitted
by an Engine.
-
+
It should not be used in other cases, as assert_compile() and
assert_sql_execution() are much better choices with fewer
moving parts.
-
+
"""
-
+
from sqlalchemy import create_engine
-
+
if not dialect_name:
dialect_name = config.db.name
@@ -186,7 +186,7 @@ def mock_engine(dialect_name=None):
def assert_sql(stmts):
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
assert recv == stmts, recv
-
+
engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
@@ -212,7 +212,7 @@ class ReplayableSession(object):
#for t in ('FunctionType', 'BuiltinFunctionType',
# 'MethodType', 'BuiltinMethodType',
# 'LambdaType', )])
-
+
# Py2K
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
@@ -243,11 +243,11 @@ class ReplayableSession(object):
else:
buffer.append(result)
return result
-
+
@property
def _sqla_unwrap(self):
return self._subject
-
+
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
@@ -280,11 +280,11 @@ class ReplayableSession(object):
return self
else:
return result
-
+
@property
def _sqla_unwrap(self):
return None
-
+
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
diff --git a/test/lib/entities.py b/test/lib/entities.py
index 0ec677eea..1b24e73b7 100644
--- a/test/lib/entities.py
+++ b/test/lib/entities.py
@@ -50,7 +50,7 @@ class ComparableEntity(BasicEntity):
self_key = sa.orm.attributes.instance_state(self).key
except sa.orm.exc.NO_STATE:
self_key = None
-
+
if other is None:
a = self
b = other
diff --git a/test/lib/pickleable.py b/test/lib/pickleable.py
index 9794e424d..acc07ceba 100644
--- a/test/lib/pickleable.py
+++ b/test/lib/pickleable.py
@@ -33,11 +33,11 @@ class OldSchool:
def __eq__(self, other):
return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
-class OldSchoolWithoutCompare:
+class OldSchoolWithoutCompare:
def __init__(self, x, y):
self.x = x
self.y = y
-
+
class BarWithoutCompare(object):
def __init__(self, x, y):
self.x = x
diff --git a/test/lib/profiling.py b/test/lib/profiling.py
index 8676adf01..bac9e549f 100644
--- a/test/lib/profiling.py
+++ b/test/lib/profiling.py
@@ -41,7 +41,7 @@ def profiled(target=None, **target_opts):
all_targets.add(target)
filename = "%s.prof" % target
-
+
@decorator
def decorate(fn, *args, **kw):
if (target not in profile_config['targets'] and
@@ -50,7 +50,7 @@ def profiled(target=None, **target_opts):
elapsed, load_stats, result = _profile(
filename, fn, *args, **kw)
-
+
graphic = target_opts.get('graphic', profile_config['graphic'])
if graphic:
os.system("runsnake %s" % filename)
@@ -68,17 +68,17 @@ def profiled(target=None, **target_opts):
stats.print_stats(limit)
else:
stats.print_stats()
-
+
print_callers = target_opts.get('print_callers',
profile_config['print_callers'])
if print_callers:
stats.print_callers()
-
+
print_callees = target_opts.get('print_callees',
profile_config['print_callees'])
if print_callees:
stats.print_callees()
-
+
os.unlink(filename)
return result
return decorate
@@ -113,7 +113,7 @@ def function_call_count(count=None, versions={}, variance=0.05):
cextension = True
except ImportError:
cextension = False
-
+
while version_info:
version = '.'.join([str(v) for v in version_info])
if cextension and (version + "+cextension") in versions:
@@ -129,7 +129,7 @@ def function_call_count(count=None, versions={}, variance=0.05):
if count is None:
print "Warning: no function call count specified for version: '%s'" % py_version
return lambda fn: fn
-
+
@decorator
def decorate(fn, *args, **kw):
try:
diff --git a/test/lib/requires.py b/test/lib/requires.py
index 222dc93f6..993a1546f 100644
--- a/test/lib/requires.py
+++ b/test/lib/requires.py
@@ -54,7 +54,7 @@ def boolean_col_expressions(fn):
no_support('maxdb', 'FIXME: verify not supported by database'),
no_support('informix', 'not supported by database'),
)
-
+
def identity(fn):
"""Target database must support GENERATED AS IDENTITY or a facsimile.
@@ -99,19 +99,19 @@ def row_triggers(fn):
# no access to same table
no_support('mysql', 'requires SUPER priv'),
exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
-
+
# huh? TODO: implement triggers for PG tests, remove this
- no_support('postgresql', 'PG triggers need to be implemented for tests'),
+ no_support('postgresql', 'PG triggers need to be implemented for tests'),
)
def correlated_outer_joins(fn):
"""Target must support an outer join to a subquery which correlates to the parent."""
-
+
return _chain_decorators_on(
fn,
no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
)
-
+
def savepoints(fn):
"""Target database must support savepoints."""
return _chain_decorators_on(
@@ -126,21 +126,21 @@ def savepoints(fn):
def denormalized_names(fn):
"""Target database must have 'denormalized', i.e. UPPERCASE as case insensitive names."""
-
+
return skip_if(
lambda: not testing.db.dialect.requires_name_normalize,
"Backend does not require denomralized names."
)(fn)
-
+
def schemas(fn):
"""Target database must support external schemas, and have one named 'test_schema'."""
-
+
return _chain_decorators_on(
fn,
no_support('sqlite', 'no schema support'),
no_support('firebird', 'no schema support')
)
-
+
def sequences(fn):
"""Target database must support SEQUENCEs."""
return _chain_decorators_on(
@@ -164,7 +164,7 @@ def update_nowait(fn):
no_support('sqlite', 'no FOR UPDATE NOWAIT support'),
no_support('sybase', 'no FOR UPDATE NOWAIT support'),
)
-
+
def subqueries(fn):
"""Target database must support subqueries."""
return _chain_decorators_on(
@@ -198,7 +198,7 @@ def offset(fn):
fn,
fails_on('sybase', 'no support for OFFSET or equivalent'),
)
-
+
def returning(fn):
return _chain_decorators_on(
fn,
@@ -209,7 +209,7 @@ def returning(fn):
no_support('sybase', 'not supported by database'),
no_support('informix', 'not supported by database'),
)
-
+
def two_phase_transactions(fn):
"""Target database must support two-phase transactions."""
return _chain_decorators_on(
@@ -257,13 +257,13 @@ def cextensions(fn):
fn,
skip_if(lambda: not _has_cextensions(), "C extensions not installed")
)
-
+
def dbapi_lastrowid(fn):
return _chain_decorators_on(
fn,
fails_on_everything_except('mysql+mysqldb', 'mysql+oursql', 'sqlite+pysqlite')
)
-
+
def sane_multi_rowcount(fn):
return _chain_decorators_on(
fn,
@@ -283,7 +283,7 @@ def reflects_pk_names(fn):
fn,
fails_on_everything_except('postgresql', 'oracle')
)
-
+
def python2(fn):
return _chain_decorators_on(
fn,
@@ -317,7 +317,7 @@ def _has_cextensions():
return True
except ImportError:
return False
-
+
def _has_sqlite():
from sqlalchemy import create_engine
try:
diff --git a/test/lib/schema.py b/test/lib/schema.py
index 614e5863e..b4aabfe76 100644
--- a/test/lib/schema.py
+++ b/test/lib/schema.py
@@ -76,4 +76,4 @@ def _truncate_name(dialect, name):
return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
else:
return name
-
+
diff --git a/test/lib/testing.py b/test/lib/testing.py
index 6a2c62959..d6338cf10 100644
--- a/test/lib/testing.py
+++ b/test/lib/testing.py
@@ -18,7 +18,7 @@ from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, \
from sqlalchemy.engine import default
from nose import SkipTest
-
+
_ops = { '<': operator.lt,
'>': operator.gt,
'==': operator.eq,
@@ -90,9 +90,9 @@ def db_spec(*dbs):
return engine.name in dialects or \
engine.driver in drivers or \
(engine.name, engine.driver) in specs
-
+
return check
-
+
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
@@ -105,7 +105,7 @@ def fails_on(dbs, reason):
"""
spec = db_spec(dbs)
-
+
@decorator
def decorate(fn, *args, **kw):
if not spec(config.db):
@@ -132,7 +132,7 @@ def fails_on_everything_except(*dbs):
"""
spec = db_spec(*dbs)
-
+
@decorator
def decorate(fn, *args, **kw):
if spec(config.db):
@@ -211,7 +211,7 @@ def only_on(dbs, reason):
print >> sys.stderr, msg
return True
return decorate
-
+
def exclude(db, op, spec, reason):
"""Mark a test as unsupported by specific database server versions.
@@ -225,7 +225,7 @@ def exclude(db, op, spec, reason):
"""
carp = _should_carp_about_exclusion(reason)
-
+
@decorator
def decorate(fn, *args, **kw):
if _is_excluded(db, op, spec):
@@ -283,7 +283,7 @@ def _server_version(bind=None):
if bind is None:
bind = config.db
-
+
# force metadata to be retrieved
conn = bind.connect()
version = getattr(bind.dialect, 'server_version_info', ())
@@ -294,7 +294,7 @@ def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
reason = reason or predicate.__name__
carp = _should_carp_about_exclusion(reason)
-
+
@decorator
def decorate(fn, *args, **kw):
if predicate():
@@ -320,7 +320,7 @@ def emits_warning(*messages):
# and may work on non-CPython if they keep to the spirit of
# warnings.showwarning's docstring.
# - update: jython looks ok, it uses cpython's module
-
+
@decorator
def decorate(fn, *args, **kw):
# todo: should probably be strict about this, too
@@ -350,7 +350,7 @@ def emits_warning_on(db, *warnings):
warnings.filterwarnings().
"""
spec = db_spec(db)
-
+
@decorator
def decorate(fn, *args, **kw):
if isinstance(db, basestring):
@@ -369,16 +369,16 @@ def emits_warning_on(db, *warnings):
def assert_warnings(fn, warnings):
"""Assert that each of the given warnings are emitted by fn."""
-
+
orig_warn = util.warn
def capture_warnings(*args, **kw):
orig_warn(*args, **kw)
popwarn = warnings.pop(0)
eq_(args[0], popwarn)
util.warn = util.langhelpers.warn = capture_warnings
-
+
return emits_warning()(fn)()
-
+
def uses_deprecated(*messages):
"""Mark a test as immune from fatal deprecation warnings.
@@ -419,7 +419,7 @@ def uses_deprecated(*messages):
def testing_warn(msg, stacklevel=3):
"""Replaces sqlalchemy.util.warn during tests."""
-
+
filename = "test.lib.testing"
lineno = 1
if isinstance(msg, basestring):
@@ -429,9 +429,9 @@ def testing_warn(msg, stacklevel=3):
def resetwarnings():
"""Reset warning behavior to testing defaults."""
-
+
util.warn = util.langhelpers.warn = testing_warn
-
+
warnings.filterwarnings('ignore',
category=sa_exc.SAPendingDeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
@@ -440,17 +440,17 @@ def resetwarnings():
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
-
+
Hardcoded at the moment, a modular system can be built here
to support things like PG prepared transactions, tables all
dropped, etc.
-
+
"""
testutil.lazy_gc()
assert not pool._refs
-
-
+
+
def against(*queries):
"""Boolean predicate, compares to testing database configuration.
@@ -523,7 +523,7 @@ def assert_raises(except_cls, callable_, *args, **kw):
success = False
except except_cls, e:
success = True
-
+
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
@@ -537,7 +537,7 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
def fail(msg):
assert False, msg
-
+
def fixture(table, columns, *rows):
"""Insert data into table after creation."""
def onload(event, schema_item, connection):
@@ -624,34 +624,34 @@ class TestBase(object):
def assert_(self, val, msg=None):
assert val, msg
-
+
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
if use_default_dialect:
dialect = default.DefaultDialect()
-
+
if dialect is None:
dialect = getattr(self, '__dialect__', None)
kw = {}
if params is not None:
kw['column_keys'] = params.keys()
-
+
if isinstance(clause, orm.Query):
context = clause._compile_context()
context.statement.use_labels = True
clause = context.statement
-
+
c = clause.compile(dialect=dialect, **kw)
param_str = repr(getattr(c, 'params', {}))
# Py3K
#param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
-
+
print "\nSQL String:\n" + str(c) + param_str
-
+
cc = re.sub(r'[\n\t]', '', str(c))
-
+
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
@@ -665,7 +665,7 @@ class ComparesTables(object):
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
-
+
if strict_types:
assert type(reflected_c.type) is type(c.type), \
"Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
@@ -683,7 +683,7 @@ class ComparesTables(object):
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
-
+
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(c2.type),\
"On column %r, type '%s' doesn't correspond to type '%s'" % \
@@ -770,13 +770,13 @@ class AssertsExecutionResults(object):
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
-
+
def assert_sql(self, db, callable_, list_, with_sequences=None):
if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
rules = with_sequences
else:
rules = list_
-
+
newrules = []
for rule in rules:
if isinstance(rule, dict):
@@ -786,7 +786,7 @@ class AssertsExecutionResults(object):
else:
newrule = assertsql.ExactSQL(*rule)
newrules.append(newrule)
-
+
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
diff --git a/test/lib/util.py b/test/lib/util.py
index 4c9892852..b512cf6b0 100644
--- a/test/lib/util.py
+++ b/test/lib/util.py
@@ -13,7 +13,7 @@ if jython:
gc.collect()
gc.collect()
return 0
-
+
# "lazy" gc, for VM's that don't GC on refcount == 0
lazy_gc = gc_collect
@@ -34,48 +34,48 @@ def picklers():
# end Py2K
import pickle
picklers.add(pickle)
-
+
# yes, this thing needs this much testing
for pickle in picklers:
for protocol in -1, 0, 1, 2:
yield pickle.loads, lambda d:pickle.dumps(d, protocol)
-
-
+
+
def round_decimal(value, prec):
if isinstance(value, float):
return round(value, prec)
-
+
# can also use shift() here but that is 2.6 only
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
pow(10, prec)
-
+
class RandomSet(set):
def __iter__(self):
l = list(set.__iter__(self))
random.shuffle(l)
return iter(l)
-
+
def pop(self):
index = random.randint(0, len(self) - 1)
item = list(set.__iter__(self))[index]
self.remove(item)
return item
-
+
def union(self, other):
return RandomSet(set.union(self, other))
-
+
def difference(self, other):
return RandomSet(set.difference(self, other))
-
+
def intersection(self, other):
return RandomSet(set.intersection(self, other))
-
+
def copy(self):
return RandomSet(self)
-
+
def conforms_partial_ordering(tuples, sorted_elements):
"""True if the given sorting conforms to the given partial ordering."""
-
+
deps = defaultdict(set)
for parent, child in tuples:
deps[parent].add(child)
@@ -101,7 +101,7 @@ def all_partial_orderings(tuples, elements):
if not subset.intersection(edges[elem]):
for sub_ordering in _all_orderings(subset):
yield [elem] + sub_ordering
-
+
return iter(_all_orderings(elements))
@@ -114,7 +114,7 @@ def function_named(fn, name):
This function should be phased out as much as possible
in favor of @decorator. Tests that "generate" many named tests
should be modernized.
-
+
"""
try:
fn.__name__ = name