summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py5
-rw-r--r--lib/sqlalchemy/testing/assertions.py27
-rw-r--r--lib/sqlalchemy/testing/assertsql.py9
-rw-r--r--lib/sqlalchemy/testing/config.py1
-rw-r--r--lib/sqlalchemy/testing/engines.py31
-rw-r--r--lib/sqlalchemy/testing/entities.py6
-rw-r--r--lib/sqlalchemy/testing/exclusions.py19
-rw-r--r--lib/sqlalchemy/testing/fixtures.py15
-rw-r--r--lib/sqlalchemy/testing/pickleable.py33
-rw-r--r--lib/sqlalchemy/testing/plugin/noseplugin.py17
-rw-r--r--lib/sqlalchemy/testing/profiling.py47
-rw-r--r--lib/sqlalchemy/testing/requirements.py7
-rw-r--r--lib/sqlalchemy/testing/runner.py1
-rw-r--r--lib/sqlalchemy/testing/schema.py3
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py4
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py4
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py16
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py6
-rw-r--r--lib/sqlalchemy/testing/util.py10
-rw-r--r--lib/sqlalchemy/testing/warnings.py4
21 files changed, 205 insertions, 65 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 15b3471aa..e571a5045 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -10,12 +10,11 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
- assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
+ assert_raises_message, AssertsCompiledSQL, ComparesTables, \
+ AssertsExecutionResults
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
crashes = skip
from .config import db, requirements as requires
-
-
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e74d13a97..ebd10b130 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -16,6 +16,7 @@ import itertools
from .util import fail
import contextlib
+
def emits_warning(*messages):
"""Mark a test as emitting a warning.
@@ -50,6 +51,7 @@ def emits_warning(*messages):
resetwarnings()
return decorate
+
def emits_warning_on(db, *warnings):
"""Mark a test as emitting a warning on a specific dialect.
@@ -115,7 +117,6 @@ def uses_deprecated(*messages):
return decorate
-
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
@@ -129,28 +130,32 @@ def global_cleanup_assertions():
assert not pool._refs, str(pool._refs)
-
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
+
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
+
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
+
def is_not_(a, b, msg=None):
"""Assert a is not b, with repr messaging on failure."""
assert a is not b, msg or "%r is %r" % (a, b)
+
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a, fragment)
+
def assert_raises(except_cls, callable_, *args, **kw):
try:
callable_(*args, **kw)
@@ -161,6 +166,7 @@ def assert_raises(except_cls, callable_, *args, **kw):
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
+
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
@@ -214,7 +220,9 @@ class AssertsCompiledSQL(object):
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+
class ComparesTables(object):
+
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -224,15 +232,19 @@ class ComparesTables(object):
eq_(c.nullable, reflected_c.nullable)
if strict_types:
+ msg = "Type '%s' doesn't correspond to type '%s'"
assert type(reflected_c.type) is type(c.type), \
- "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ msg % (reflected_c.type, c.type)
else:
self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
- eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ eq_(
+ set([f.column.name for f in c.foreign_keys]),
+ set([f.column.name for f in reflected_c.foreign_keys])
+ )
if c.server_default:
assert isinstance(reflected_c.server_default,
schema.FetchedValue)
@@ -246,6 +258,7 @@ class ComparesTables(object):
"On column %r, type '%s' doesn't correspond to type '%s'" % \
(c1.name, c1.type, c2.type)
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
@@ -296,6 +309,7 @@ class AssertsExecutionResults(object):
len(found), len(expected)))
NOVALUE = object()
+
def _compare_item(obj, spec):
for key, value in spec.iteritems():
if isinstance(value, tuple):
@@ -347,7 +361,8 @@ class AssertsExecutionResults(object):
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
- self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+ self.assert_sql_execution(
+ db, callable_, assertsql.CountStatements(count))
@contextlib.contextmanager
def assert_execution(self, *rules):
@@ -359,4 +374,4 @@ class AssertsExecutionResults(object):
assertsql.asserter.clear_rules()
def assert_statement_count(self, count):
- return self.assert_execution(assertsql.CountStatements(count)) \ No newline at end of file
+ return self.assert_execution(assertsql.CountStatements(count))
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 08ee55d57..d955d1554 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -3,6 +3,7 @@ from ..engine.default import DefaultDialect
from .. import util
import re
+
class AssertRule(object):
def process_execute(self, clauseelement, *multiparams, **params):
@@ -40,6 +41,7 @@ class AssertRule(object):
assert False, 'Rule has not been consumed'
return self.is_consumed()
+
class SQLMatchRule(AssertRule):
def __init__(self):
self._result = None
@@ -56,6 +58,7 @@ class SQLMatchRule(AssertRule):
return True
+
class ExactSQL(SQLMatchRule):
def __init__(self, sql, params=None):
@@ -138,6 +141,7 @@ class RegexSQL(SQLMatchRule):
_received_statement,
_received_parameters)
+
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params):
@@ -217,6 +221,7 @@ class CountStatements(AssertRule):
% (self.count, self._statement_count)
return True
+
class AllOf(AssertRule):
def __init__(self, *rules):
@@ -244,6 +249,7 @@ class AllOf(AssertRule):
def consume_final(self):
return len(self.rules) == 0
+
def _process_engine_statement(query, context):
if util.jython:
@@ -256,6 +262,7 @@ def _process_engine_statement(query, context):
query = re.sub(r'\n', '', query)
return query
+
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
@@ -275,6 +282,7 @@ def _process_assertion_statement(query, context):
return query
+
class SQLAssert(object):
rules = None
@@ -311,4 +319,3 @@ class SQLAssert(object):
executemany)
asserter = SQLAssert()
-
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 2945bd456..ae4f585e1 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -1,3 +1,2 @@
requirements = None
db = None
-
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 9d15c5078..20bcf0317 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -9,7 +9,9 @@ from .. import event, pool
import re
import warnings
+
class ConnectionKiller(object):
+
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
@@ -83,12 +85,14 @@ class ConnectionKiller(object):
testing_reaper = ConnectionKiller()
+
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
if hasattr(bind, 'close'):
bind.close()
metadata.drop_all(bind)
+
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
@@ -96,6 +100,7 @@ def assert_conns_closed(fn, *args, **kw):
finally:
testing_reaper.assert_all_closed()
+
@decorator
def rollback_open_connections(fn, *args, **kw):
"""Decorator that rolls back all open connections after fn execution."""
@@ -105,6 +110,7 @@ def rollback_open_connections(fn, *args, **kw):
finally:
testing_reaper.rollback_all()
+
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
@@ -121,6 +127,7 @@ def close_open_connections(fn, *args, **kw):
finally:
testing_reaper.close_all()
+
def all_dialects(exclude=None):
import sqlalchemy.databases as d
for name in d.__all__:
@@ -129,10 +136,13 @@ def all_dialects(exclude=None):
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(__import__(
+ 'sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
+
class ReconnectFixture(object):
+
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
@@ -165,6 +175,7 @@ class ReconnectFixture(object):
self._safe(c.close)
self.connections = []
+
def reconnecting_engine(url=None, options=None):
url = url or config.db_url
dbapi = config.db.dialect.dbapi
@@ -173,9 +184,11 @@ def reconnecting_engine(url=None, options=None):
options['module'] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
+
def dispose():
engine.dialect.dbapi.shutdown()
_dispose()
+
engine.test_shutdown = engine.dialect.dbapi.shutdown
engine.dispose = dispose
return engine
@@ -209,6 +222,7 @@ def testing_engine(url=None, options=None):
return engine
+
def utf8_engine(url=None, options=None):
"""Hook for dialects or drivers that don't handle utf8 by default."""
@@ -226,6 +240,7 @@ def utf8_engine(url=None, options=None):
return testing_engine(url, options)
+
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
@@ -244,17 +259,21 @@ def mock_engine(dialect_name=None):
dialect_name = config.db.name
buffer = []
+
def executor(sql, *a, **kw):
buffer.append(sql)
+
def assert_sql(stmts):
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
assert recv == stmts, recv
+
def print_sql():
d = engine.dialect
return "\n".join(
str(s.compile(dialect=d))
for s in engine.mock
)
+
engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
@@ -263,6 +282,7 @@ def mock_engine(dialect_name=None):
engine.print_sql = print_sql
return engine
+
class DBAPIProxyCursor(object):
"""Proxy a DBAPI cursor.
@@ -287,6 +307,7 @@ class DBAPIProxyCursor(object):
def __getattr__(self, key):
return getattr(self.cursor, key)
+
class DBAPIProxyConnection(object):
"""Proxy a DBAPI connection.
@@ -308,14 +329,17 @@ class DBAPIProxyConnection(object):
def __getattr__(self, key):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
+
+def proxying_engine(conn_cls=DBAPIProxyConnection,
+ cursor_cls=DBAPIProxyCursor):
"""Produce an engine that provides proxy hooks for
common methods.
"""
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator':mock_conn})
+ return testing_engine(options={'creator': mock_conn})
+
class ReplayableSession(object):
"""A simple record/playback tool.
@@ -427,4 +451,3 @@ class ReplayableSession(object):
raise AttributeError(key)
else:
return result
-
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
index 1b24e73b7..5c5e69154 100644
--- a/lib/sqlalchemy/testing/entities.py
+++ b/lib/sqlalchemy/testing/entities.py
@@ -2,7 +2,10 @@ import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
_repr_stack = set()
+
+
class BasicEntity(object):
+
def __init__(self, **kw):
for key, value in kw.iteritems():
setattr(self, key, value)
@@ -21,7 +24,10 @@ class BasicEntity(object):
_repr_stack.remove(id(self))
_recursion_stack = set()
+
+
class ComparableEntity(BasicEntity):
+
def __hash__(self):
return hash(self.__class__)
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 3c70ec8d9..f105c8b6a 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -61,6 +61,7 @@ class skip_if(object):
self._fails_on = skip_if(other, reason)
return self
+
class fails_if(skip_if):
def __call__(self, fn):
@decorator
@@ -69,14 +70,17 @@ class fails_if(skip_if):
return fn(*args, **kw)
return decorate(fn)
+
def only_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return skip_if(NotPredicate(predicate), reason)
+
def succeeds_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return fails_if(NotPredicate(predicate), reason)
+
class Predicate(object):
@classmethod
def as_predicate(cls, predicate):
@@ -93,6 +97,7 @@ class Predicate(object):
else:
assert False, "unknown predicate type: %s" % predicate
+
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
self.value = value
@@ -110,6 +115,7 @@ class BooleanPredicate(Predicate):
def __str__(self):
return self._as_string()
+
class SpecPredicate(Predicate):
def __init__(self, db, op=None, spec=None, description=None):
self.db = db
@@ -177,6 +183,7 @@ class SpecPredicate(Predicate):
def __str__(self):
return self._as_string()
+
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
self.lambda_ = lambda_
@@ -201,6 +208,7 @@ class LambdaPredicate(Predicate):
def __str__(self):
return self._as_string()
+
class NotPredicate(Predicate):
def __init__(self, predicate):
self.predicate = predicate
@@ -211,6 +219,7 @@ class NotPredicate(Predicate):
def __str__(self):
return self.predicate._as_string(True)
+
class OrPredicate(Predicate):
def __init__(self, predicates, description=None):
self.predicates = predicates
@@ -256,9 +265,11 @@ class OrPredicate(Predicate):
_as_predicate = Predicate.as_predicate
+
def _is_excluded(db, op, spec):
return SpecPredicate(db, op, spec)()
+
def _server_version(engine):
"""Return a server_version_info tuple."""
@@ -268,24 +279,30 @@ def _server_version(engine):
conn.close()
return version
+
def db_spec(*dbs):
return OrPredicate(
Predicate.as_predicate(db) for db in dbs
)
+
def open():
return skip_if(BooleanPredicate(False, "mark as execute"))
+
def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
+
@decorator
def future(fn, *args, **kw):
return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
+
def fails_on(db, reason=None):
return fails_if(SpecPredicate(db), reason)
+
def fails_on_everything_except(*dbs):
return succeeds_if(
OrPredicate([
@@ -293,9 +310,11 @@ def fails_on_everything_except(*dbs):
])
)
+
def skip(db, reason=None):
return skip_if(SpecPredicate(db), reason)
+
def only_on(dbs, reason=None):
return only_if(
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 1a1204898..5c587cb2f 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -7,6 +7,7 @@ import sys
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
+
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
@@ -29,6 +30,7 @@ class TestBase(object):
def assert_(self, val, msg=None):
assert val, msg
+
class TablesTest(TestBase):
# 'once', None
@@ -208,9 +210,11 @@ class _ORMTest(object):
sa.orm.session.Session.close_all()
sa.orm.clear_mappers()
+
class ORMTest(_ORMTest, TestBase):
pass
+
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = 'once'
@@ -252,7 +256,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
cls.classes.clear()
_ORMTest.teardown_class()
-
@classmethod
def _setup_once_classes(cls):
if cls.run_setup_classes == 'once':
@@ -275,18 +278,21 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
"""
cls_registry = cls.classes
+
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return type.__init__(cls, classname, bases, dict_)
-
class _Base(object):
__metaclass__ = FindFixture
+
class Basic(BasicEntity, _Base):
pass
+
class Comparable(ComparableEntity, _Base):
pass
+
cls.Basic = Basic
cls.Comparable = Comparable
fn()
@@ -306,6 +312,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
def setup_mappers(cls):
pass
+
class DeclarativeMappedTest(MappedTest):
run_setup_classes = 'once'
run_setup_mappers = 'once'
@@ -317,17 +324,21 @@ class DeclarativeMappedTest(MappedTest):
@classmethod
def _with_register_classes(cls, fn):
cls_registry = cls.classes
+
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return DeclarativeMeta.__init__(
cls, classname, bases, dict_)
+
class DeclarativeBasic(object):
__table_cls__ = schema.Table
+
_DeclBase = declarative_base(metadata=cls.metadata,
metaclass=FindFixtureDeclarative,
cls=DeclarativeBasic)
cls.DeclarativeBasic = _DeclBase
fn()
+
if cls.metadata.tables:
cls.metadata.create_all(config.db)
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
index f5b8b827c..09d51b5fa 100644
--- a/lib/sqlalchemy/testing/pickleable.py
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -1,43 +1,59 @@
-"""Classes used in pickling tests, need to be at the module level for unpickling."""
+"""Classes used in pickling tests, need to be at the module level for
+unpickling.
+"""
from . import fixtures
+
class User(fixtures.ComparableEntity):
pass
+
class Order(fixtures.ComparableEntity):
pass
+
class Dingaling(fixtures.ComparableEntity):
pass
+
class EmailUser(User):
pass
+
class Address(fixtures.ComparableEntity):
pass
+
# TODO: these are kind of arbitrary....
class Child1(fixtures.ComparableEntity):
pass
+
class Child2(fixtures.ComparableEntity):
pass
+
class Parent(fixtures.ComparableEntity):
pass
+
class Screen(object):
+
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
+
class Foo(object):
+
def __init__(self, moredata):
self.data = 'im data'
self.stuff = 'im stuff'
self.moredata = moredata
+
__hash__ = object.__hash__
+
def __eq__(self, other):
return other.data == self.data and \
other.stuff == self.stuff and \
@@ -45,40 +61,53 @@ class Foo(object):
class Bar(object):
+
def __init__(self, x, y):
self.x = x
self.y = y
+
__hash__ = object.__hash__
+
def __eq__(self, other):
return other.__class__ is self.__class__ and \
other.x == self.x and \
other.y == self.y
+
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
+
class OldSchool:
+
def __init__(self, x, y):
self.x = x
self.y = y
+
def __eq__(self, other):
return other.__class__ is self.__class__ and \
other.x == self.x and \
other.y == self.y
+
class OldSchoolWithoutCompare:
+
def __init__(self, x, y):
self.x = x
self.y = y
+
class BarWithoutCompare(object):
+
def __init__(self, x, y):
self.x = x
self.y = y
+
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class NotComparable(object):
+
def __init__(self, data):
self.data = data
@@ -93,6 +122,7 @@ class NotComparable(object):
class BrokenComparable(object):
+
def __init__(self, data):
self.data = data
@@ -104,4 +134,3 @@ class BrokenComparable(object):
def __ne__(self, other):
raise NotImplementedError
-
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 37f7b29f5..c104c4614 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -41,6 +41,7 @@ db_opts = {}
options = None
_existing_engine = None
+
def _log(option, opt_str, value, parser):
global logging
if not logging:
@@ -59,34 +60,42 @@ def _list_dbs(*args):
print "%20s\t%s" % (macro, file_config.get('db', macro))
sys.exit(0)
+
def _server_side_cursors(options, opt_str, value, parser):
db_opts['server_side_cursors'] = True
+
def _engine_strategy(options, opt_str, value, parser):
if value:
db_opts['strategy'] = value
pre_configure = []
post_configure = []
+
+
def pre(fn):
pre_configure.append(fn)
return fn
+
+
def post(fn):
post_configure.append(fn)
return fn
+
@pre
def _setup_options(opt, file_config):
global options
options = opt
+
@pre
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
- import sys
import cdecimal
sys.modules['decimal'] = cdecimal
+
@post
def _engine_uri(options, file_config):
global db_label, db_url
@@ -105,6 +114,7 @@ def _engine_uri(options, file_config):
% db_label)
db_url = file_config.get('db', db_label)
+
@post
def _require(options, file_config):
if not(options.require or
@@ -131,12 +141,14 @@ def _require(options, file_config):
continue
pkg_resources.require(requirement)
+
@post
def _engine_pool(options, file_config):
if options.mockpool:
from sqlalchemy import pool
db_opts['poolclass'] = pool.AssertionPool
+
@post
def _create_testing_engine(options, file_config):
from sqlalchemy.testing import engines, config
@@ -199,6 +211,7 @@ def _set_table_options(options, file_config):
if options.mysql_engine:
table_options['mysql_engine'] = options.mysql_engine
+
@post
def _reverse_topological(options, file_config):
if options.reversetop:
@@ -208,6 +221,7 @@ def _reverse_topological(options, file_config):
topological.set = unitofwork.set = session.set = mapper.set = \
dependency.set = RandomSet
+
@post
def _requirements(options, file_config):
from sqlalchemy.testing import config
@@ -230,6 +244,7 @@ def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
config.options = options
+
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
index a22e83cbc..ae9d176b7 100644
--- a/lib/sqlalchemy/testing/profiling.py
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -22,6 +22,7 @@ from ..util.compat import jython, pypy, win32
_current_test = None
+
def profiled(target=None, **target_opts):
"""Function profiling.
@@ -69,13 +70,13 @@ def profiled(target=None, **target_opts):
else:
stats.print_stats()
- print_callers = target_opts.get('print_callers',
- profile_config['print_callers'])
+ print_callers = target_opts.get(
+ 'print_callers', profile_config['print_callers'])
if print_callers:
stats.print_callers()
- print_callees = target_opts.get('print_callees',
- profile_config['print_callees'])
+ print_callees = target_opts.get(
+ 'print_callees', profile_config['print_callees'])
if print_callees:
stats.print_callees()
@@ -92,10 +93,14 @@ class ProfileStatsFile(object):
"""
def __init__(self, filename):
- self.write = config.options is not None and config.options.write_profiles
+ self.write = (
+ config.options is not None and
+ config.options.write_profiles
+ )
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
- self.data = collections.defaultdict(lambda: collections.defaultdict(dict))
+ self.data = collections.defaultdict(
+ lambda: collections.defaultdict(dict))
self._read()
if self.write:
# rewrite for the case where features changed,
@@ -124,7 +129,10 @@ class ProfileStatsFile(object):
def has_stats(self):
test_key = _current_test
- return test_key in self.data and self.platform_key in self.data[test_key]
+ return (
+ test_key in self.data and
+ self.platform_key in self.data[test_key]
+ )
def result(self, callcount):
test_key = _current_test
@@ -153,7 +161,6 @@ class ProfileStatsFile(object):
per_platform['current_count'] += 1
return result
-
def _header(self):
return \
"# %s\n"\
@@ -165,8 +172,8 @@ class ProfileStatsFile(object):
"# assertions are raised if the counts do not match.\n"\
"# \n"\
"# To add a new callcount test, apply the function_call_count \n"\
- "# decorator and re-run the tests using the --write-profiles option - \n"\
- "# this file will be rewritten including the new count.\n"\
+ "# decorator and re-run the tests using the --write-profiles \n"\
+ "# option - this file will be rewritten including the new count.\n"\
"# \n"\
"" % (self.fname)
@@ -183,7 +190,8 @@ class ProfileStatsFile(object):
test_key, platform_key, counts = line.split()
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
- per_platform['counts'] = [int(count) for count in counts.split(",")]
+ c = [int(count) for count in counts.split(",")]
+ per_platform['counts'] = c
per_platform['lineno'] = lineno + 1
per_platform['current_count'] = 0
profile_f.close()
@@ -198,16 +206,13 @@ class ProfileStatsFile(object):
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
- profile_f.write(
- "%s %s %s\n" % (
- test_key,
- platform_key, ",".join(str(count) for count in per_platform['counts'])
- )
- )
+ c = ",".join(str(count) for count in per_platform['counts'])
+ profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
from sqlalchemy.util.compat import update_wrapper
+
def function_call_count(variance=0.05):
"""Assert a target for a test case's function call count.
@@ -222,7 +227,6 @@ def function_call_count(variance=0.05):
def decorate(fn):
def wrap(*args, **kw):
-
if cProfile is None:
raise SkipTest("cProfile is not installed")
@@ -237,7 +241,6 @@ def function_call_count(variance=0.05):
gc_collect()
-
timespent, load_stats, fn_result = _profile(
fn, *args, **kw
)
@@ -263,8 +266,9 @@ def function_call_count(variance=0.05):
if abs(callcount - expected_count) > deviance:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
- "of expected %s. (Delete line %d of file %s to regenerate "
- "this callcount, when tests are run with --write-profiles.)"
+ "of expected %s. (Delete line %d of file %s to "
+ "regenerate this callcount, when tests are run "
+ "with --write-profiles.)"
% (
callcount, (variance * 100),
expected_count, line_no,
@@ -288,4 +292,3 @@ def _profile(fn, *args, **kw):
ended = time.time()
return ended - began, load_stats, locals()['result']
-
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index d58538db9..68659a855 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -10,6 +10,7 @@ to provide specific inclusion/exlusions.
from . import exclusions
+
class Requirements(object):
def __init__(self, db, config):
self.db = db
@@ -178,7 +179,6 @@ class SuiteRequirements(Requirements):
"""
return exclusions.open()
-
@property
def datetime(self):
"""target dialect supports representation of Python
@@ -237,8 +237,10 @@ class SuiteRequirements(Requirements):
@property
def empty_strings_varchar(self):
- """target database can persist/return an empty string with a varchar."""
+ """target database can persist/return an empty string with a
+ varchar.
+ """
return exclusions.open()
@property
@@ -248,7 +250,6 @@ class SuiteRequirements(Requirements):
return exclusions.open()
-
@property
def update_from(self):
"""Target must support UPDATE..FROM syntax"""
diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py
index 1a4ba5212..6ec73d7c8 100644
--- a/lib/sqlalchemy/testing/runner.py
+++ b/lib/sqlalchemy/testing/runner.py
@@ -28,5 +28,6 @@ from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy
import nose
+
def main():
nose.main(addplugins=[NoseSQLAlchemy()])
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 5dfdc0e07..ad233ec22 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -7,6 +7,7 @@ __all__ = 'Table', 'Column',
table_options = {}
+
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
@@ -76,10 +77,10 @@ def Column(*args, **kw):
event.listen(col, 'after_parent_attach', add_seq, propagate=True)
return col
+
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
return name[0:max(dialect.max_identifier_length - 6, 0)] + \
"_" + hex(hash(name) % 64)[2:]
else:
return name
-
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
index 466429aa5..c5b162413 100644
--- a/lib/sqlalchemy/testing/suite/test_ddl.py
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -25,7 +25,6 @@ class TableDDLTest(fixtures.TestBase):
(1, 'some data')
)
-
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
@@ -35,7 +34,6 @@ class TableDDLTest(fixtures.TestBase):
)
self._simple_roundtrip()
-
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
@@ -48,4 +46,4 @@ class TableDDLTest(fixtures.TestBase):
)
-__all__ = ('TableDDLTest', ) \ No newline at end of file
+__all__ = ('TableDDLTest', )
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index 3cd7d39bc..b2b2a0aa8 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -8,6 +8,7 @@ from sqlalchemy import Integer, String, select, util
from ..schema import Table, Column
+
class LastrowidTest(fixtures.TablesTest):
run_deletes = 'each'
@@ -88,7 +89,6 @@ class InsertBehaviorTest(fixtures.TablesTest):
else:
engine = config.db
-
r = engine.execute(
self.tables.autoinc_pk.insert(),
data="some data"
@@ -107,6 +107,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
assert r.is_insert
assert not r.returns_rows
+
class ReturningTest(fixtures.TablesTest):
run_deletes = 'each'
__requires__ = 'returning', 'autoincrement_insert'
@@ -162,5 +163,3 @@ class ReturningTest(fixtures.TablesTest):
__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')
-
-
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index a7c814db7..b9894347a 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -18,6 +18,7 @@ from sqlalchemy import event
metadata, users = None, None
+
class HasTableTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
@@ -31,6 +32,7 @@ class HasTableTest(fixtures.TablesTest):
assert config.db.dialect.has_table(conn, "test_table")
assert not config.db.dialect.has_table(conn, "nonexistent_table")
+
class HasSequenceTest(fixtures.TestBase):
__requires__ = 'sequences',
@@ -425,4 +427,4 @@ class ComponentReflectionTest(fixtures.TablesTest):
self._test_get_table_oid('users', schema='test_schema')
-__all__ = ('ComponentReflectionTest', 'HasSequenceTest', 'HasTableTest') \ No newline at end of file
+__all__ = ('ComponentReflectionTest', 'HasSequenceTest', 'HasTableTest')
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 74cb52c6e..8d0500d71 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -8,6 +8,7 @@ from sqlalchemy import Date, DateTime, Time, MetaData, String
from ..schema import Table, Column
import datetime
+
class _UnicodeFixture(object):
__requires__ = 'unicode_data',
@@ -70,7 +71,6 @@ class _UnicodeFixture(object):
for row in rows:
assert isinstance(row[0], unicode)
-
def _test_empty_strings(self):
unicode_table = self.tables.unicode_table
@@ -83,16 +83,17 @@ class _UnicodeFixture(object):
).first()
eq_(row, (u'',))
+
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
__requires__ = 'unicode_data',
datatype = Unicode(255)
-
@requirements.empty_strings_varchar
def test_empty_strings_varchar(self):
self._test_empty_strings()
+
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
__requires__ = 'unicode_data', 'text_type'
@@ -114,6 +115,7 @@ class StringTest(fixtures.TestBase):
foo.create(config.db)
foo.drop(config.db)
+
class _DateFixture(object):
compare = None
@@ -165,37 +167,44 @@ class DateTimeTest(_DateFixture, fixtures.TablesTest):
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'datetime_microseconds',
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
class TimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'time',
datatype = Time
data = datetime.time(12, 57, 18)
+
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'time_microseconds',
datatype = Time
data = datetime.time(12, 57, 18, 396)
+
class DateTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date',
datatype = Date
data = datetime.date(2012, 10, 15)
+
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date',
datatype = Date
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
compare = datetime.date(2012, 10, 15)
+
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'datetime_historic',
datatype = DateTime
data = datetime.datetime(1850, 11, 10, 11, 52, 35)
+
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date_historic',
datatype = Date
@@ -207,6 +216,3 @@ __all__ = ('UnicodeVarcharTest', 'UnicodeTextTest',
'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest',
'DateHistoricTest', 'StringTest')
-
-
-
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
index e73b05485..a3456ac2a 100644
--- a/lib/sqlalchemy/testing/suite/test_update_delete.py
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -1,9 +1,7 @@
from .. import fixtures, config
-from ..config import requirements
from ..assertions import eq_
-from .. import engines
-from sqlalchemy import Integer, String, select
+from sqlalchemy import Integer, String
from ..schema import Table, Column
@@ -61,4 +59,4 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
]
)
-__all__ = ('SimpleUpdateDeleteTest', ) \ No newline at end of file
+__all__ = ('SimpleUpdateDeleteTest', )
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 41b0a30b3..2592c341e 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -26,9 +26,11 @@ elif pypy:
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
gc_collect = gc.collect
+
def lazy_gc():
pass
+
def picklers():
picklers = set()
# Py2K
@@ -56,6 +58,7 @@ def round_decimal(value, prec):
).to_integral(decimal.ROUND_FLOOR) / \
pow(10, prec)
+
class RandomSet(set):
def __iter__(self):
l = list(set.__iter__(self))
@@ -80,6 +83,7 @@ class RandomSet(set):
def copy(self):
return RandomSet(self)
+
def conforms_partial_ordering(tuples, sorted_elements):
"""True if the given sorting conforms to the given partial ordering."""
@@ -93,6 +97,7 @@ def conforms_partial_ordering(tuples, sorted_elements):
else:
return True
+
def all_partial_orderings(tuples, elements):
edges = defaultdict(set)
for parent, child in tuples:
@@ -131,7 +136,6 @@ def function_named(fn, name):
return fn
-
def run_as_contextmanager(ctx, fn, *arg, **kw):
"""Run the given function under the given contextmanager,
simulating the behavior of 'with' to support older
@@ -152,6 +156,7 @@ def run_as_contextmanager(ctx, fn, *arg, **kw):
else:
return raise_
+
def rowset(results):
"""Converts the results of sql execution into a plain set of column tuples.
@@ -182,6 +187,7 @@ def provide_metadata(fn, *args, **kw):
metadata.drop_all()
self.metadata = prev_meta
+
class adict(dict):
"""Dict keys available as attributes. Shadows."""
def __getattribute__(self, key):
@@ -192,5 +198,3 @@ class adict(dict):
def get_all(self, *keys):
return tuple([self[key] for key in keys])
-
-
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 7afcc63c5..41f3dbfed 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -4,6 +4,7 @@ import warnings
from .. import exc as sa_exc
from .. import util
+
def testing_warn(msg, stacklevel=3):
"""Replaces sqlalchemy.util.warn during tests."""
@@ -14,6 +15,7 @@ def testing_warn(msg, stacklevel=3):
else:
warnings.warn_explicit(msg, filename, lineno)
+
def resetwarnings():
"""Reset warning behavior to testing defaults."""
@@ -24,6 +26,7 @@ def resetwarnings():
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SAWarning)
+
def assert_warnings(fn, warnings):
"""Assert that each of the given warnings are emitted by fn."""
@@ -31,6 +34,7 @@ def assert_warnings(fn, warnings):
canary = []
orig_warn = util.warn
+
def capture_warnings(*args, **kw):
orig_warn(*args, **kw)
popwarn = warnings.pop(0)