diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
21 files changed, 205 insertions, 65 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 15b3471aa..e571a5045 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -10,12 +10,11 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ eq_, ne_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults + assert_raises_message, AssertsCompiledSQL, ComparesTables, \ + AssertsExecutionResults from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict crashes = skip from .config import db, requirements as requires - - diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e74d13a97..ebd10b130 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -16,6 +16,7 @@ import itertools from .util import fail import contextlib + def emits_warning(*messages): """Mark a test as emitting a warning. @@ -50,6 +51,7 @@ def emits_warning(*messages): resetwarnings() return decorate + def emits_warning_on(db, *warnings): """Mark a test as emitting a warning on a specific dialect. @@ -115,7 +117,6 @@ def uses_deprecated(*messages): return decorate - def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. @@ -129,28 +130,32 @@ def global_cleanup_assertions(): assert not pool._refs, str(pool._refs) - def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) + def ne_(a, b, msg=None): """Assert a != b, with repr messaging on failure.""" assert a != b, msg or "%r == %r" % (a, b) + def is_(a, b, msg=None): """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) + def is_not_(a, b, msg=None): """Assert a is not b, with repr messaging on failure.""" assert a is not b, msg or "%r is %r" % (a, b) + def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( a, fragment) + def assert_raises(except_cls, callable_, *args, **kw): try: callable_(*args, **kw) @@ -161,6 +166,7 @@ def assert_raises(except_cls, callable_, *args, **kw): # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" + def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): try: callable_(*args, **kwargs) @@ -214,7 +220,9 @@ class AssertsCompiledSQL(object): p = c.construct_params(params) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -224,15 +232,19 @@ class ComparesTables(object): eq_(c.nullable, reflected_c.nullable) if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" assert type(reflected_c.type) is type(c.type), \ - "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + msg % (reflected_c.type, c.type) else: self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) - eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + eq_( + set([f.column.name for f in c.foreign_keys]), + set([f.column.name for f in reflected_c.foreign_keys]) + ) if c.server_default: assert isinstance(reflected_c.server_default, schema.FetchedValue) @@ -246,6 +258,7 @@ class ComparesTables(object): "On column %r, type '%s' doesn't correspond to type '%s'" % \ (c1.name, c1.type, c2.type) + class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): result = list(result) @@ -296,6 +309,7 @@ class AssertsExecutionResults(object): len(found), len(expected))) NOVALUE = object() + def _compare_item(obj, spec): for key, value in spec.iteritems(): if isinstance(value, tuple): @@ -347,7 +361,8 @@ class AssertsExecutionResults(object): self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): - self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count)) @contextlib.contextmanager def assert_execution(self, *rules): @@ -359,4 +374,4 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_statement_count(self, count): - return self.assert_execution(assertsql.CountStatements(count))
\ No newline at end of file + return self.assert_execution(assertsql.CountStatements(count)) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 08ee55d57..d955d1554 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -3,6 +3,7 @@ from ..engine.default import DefaultDialect from .. import util import re + class AssertRule(object): def process_execute(self, clauseelement, *multiparams, **params): @@ -40,6 +41,7 @@ class AssertRule(object): assert False, 'Rule has not been consumed' return self.is_consumed() + class SQLMatchRule(AssertRule): def __init__(self): self._result = None @@ -56,6 +58,7 @@ class SQLMatchRule(AssertRule): return True + class ExactSQL(SQLMatchRule): def __init__(self, sql, params=None): @@ -138,6 +141,7 @@ class RegexSQL(SQLMatchRule): _received_statement, _received_parameters) + class CompiledSQL(SQLMatchRule): def __init__(self, statement, params): @@ -217,6 +221,7 @@ class CountStatements(AssertRule): % (self.count, self._statement_count) return True + class AllOf(AssertRule): def __init__(self, *rules): @@ -244,6 +249,7 @@ class AllOf(AssertRule): def consume_final(self): return len(self.rules) == 0 + def _process_engine_statement(query, context): if util.jython: @@ -256,6 +262,7 @@ def _process_engine_statement(query, context): query = re.sub(r'\n', '', query) return query + def _process_assertion_statement(query, context): paramstyle = context.dialect.paramstyle if paramstyle == 'named': @@ -275,6 +282,7 @@ def _process_assertion_statement(query, context): return query + class SQLAssert(object): rules = None @@ -311,4 +319,3 @@ class SQLAssert(object): executemany) asserter = SQLAssert() - diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 2945bd456..ae4f585e1 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -1,3 +1,2 @@ requirements = None db = None - diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 9d15c5078..20bcf0317 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -9,7 +9,9 @@ from .. import event, pool import re import warnings + class ConnectionKiller(object): + def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -83,12 +85,14 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() + def drop_all_tables(metadata, bind): testing_reaper.close_all() if hasattr(bind, 'close'): bind.close() metadata.drop_all(bind) + @decorator def assert_conns_closed(fn, *args, **kw): try: @@ -96,6 +100,7 @@ def assert_conns_closed(fn, *args, **kw): finally: testing_reaper.assert_all_closed() + @decorator def rollback_open_connections(fn, *args, **kw): """Decorator that rolls back all open connections after fn execution.""" @@ -105,6 +110,7 @@ def rollback_open_connections(fn, *args, **kw): finally: testing_reaper.rollback_all() + @decorator def close_first(fn, *args, **kw): """Decorator that closes all connections before fn execution.""" @@ -121,6 +127,7 @@ def close_open_connections(fn, *args, **kw): finally: testing_reaper.close_all() + def all_dialects(exclude=None): import sqlalchemy.databases as d for name in d.__all__: @@ -129,10 +136,13 @@ def all_dialects(exclude=None): continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + mod = getattr(__import__( + 'sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() + class ReconnectFixture(object): + def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -165,6 +175,7 @@ class ReconnectFixture(object): self._safe(c.close) self.connections = [] + def reconnecting_engine(url=None, options=None): url = url or config.db_url dbapi = config.db.dialect.dbapi @@ -173,9 +184,11 @@ def reconnecting_engine(url=None, options=None): options['module'] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose + def dispose(): engine.dialect.dbapi.shutdown() _dispose() + engine.test_shutdown = engine.dialect.dbapi.shutdown engine.dispose = dispose return engine @@ -209,6 +222,7 @@ def testing_engine(url=None, options=None): return engine + def utf8_engine(url=None, options=None): """Hook for dialects or drivers that don't handle utf8 by default.""" @@ -226,6 +240,7 @@ def utf8_engine(url=None, options=None): return testing_engine(url, options) + def mock_engine(dialect_name=None): """Provides a mocking engine based on the current testing.db. @@ -244,17 +259,21 @@ def mock_engine(dialect_name=None): dialect_name = config.db.name buffer = [] + def executor(sql, *a, **kw): buffer.append(sql) + def assert_sql(stmts): recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] assert recv == stmts, recv + def print_sql(): d = engine.dialect return "\n".join( str(s.compile(dialect=d)) for s in engine.mock ) + engine = create_engine(dialect_name + '://', strategy='mock', executor=executor) assert not hasattr(engine, 'mock') @@ -263,6 +282,7 @@ def mock_engine(dialect_name=None): engine.print_sql = print_sql return engine + class DBAPIProxyCursor(object): """Proxy a DBAPI cursor. @@ -287,6 +307,7 @@ class DBAPIProxyCursor(object): def __getattr__(self, key): return getattr(self.cursor, key) + class DBAPIProxyConnection(object): """Proxy a DBAPI connection. @@ -308,14 +329,17 @@ class DBAPIProxyConnection(object): def __getattr__(self, key): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor): + +def proxying_engine(conn_cls=DBAPIProxyConnection, + cursor_cls=DBAPIProxyCursor): """Produce an engine that provides proxy hooks for common methods. """ def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator':mock_conn}) + return testing_engine(options={'creator': mock_conn}) + class ReplayableSession(object): """A simple record/playback tool. @@ -427,4 +451,3 @@ class ReplayableSession(object): raise AttributeError(key) else: return result - diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 1b24e73b7..5c5e69154 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -2,7 +2,10 @@ import sqlalchemy as sa from sqlalchemy import exc as sa_exc _repr_stack = set() + + class BasicEntity(object): + def __init__(self, **kw): for key, value in kw.iteritems(): setattr(self, key, value) @@ -21,7 +24,10 @@ class BasicEntity(object): _repr_stack.remove(id(self)) _recursion_stack = set() + + class ComparableEntity(BasicEntity): + def __hash__(self): return hash(self.__class__) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 3c70ec8d9..f105c8b6a 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -61,6 +61,7 @@ class skip_if(object): self._fails_on = skip_if(other, reason) return self + class fails_if(skip_if): def __call__(self, fn): @decorator @@ -69,14 +70,17 @@ class fails_if(skip_if): return fn(*args, **kw) return decorate(fn) + def only_if(predicate, reason=None): predicate = _as_predicate(predicate) return skip_if(NotPredicate(predicate), reason) + def succeeds_if(predicate, reason=None): predicate = _as_predicate(predicate) return fails_if(NotPredicate(predicate), reason) + class Predicate(object): @classmethod def as_predicate(cls, predicate): @@ -93,6 +97,7 @@ class Predicate(object): else: assert False, "unknown predicate type: %s" % predicate + class BooleanPredicate(Predicate): def __init__(self, value, description=None): self.value = value @@ -110,6 +115,7 @@ class BooleanPredicate(Predicate): def __str__(self): return self._as_string() + class SpecPredicate(Predicate): def __init__(self, db, op=None, spec=None, description=None): self.db = db @@ -177,6 +183,7 @@ class SpecPredicate(Predicate): def __str__(self): return self._as_string() + class LambdaPredicate(Predicate): def __init__(self, lambda_, description=None, args=None, kw=None): self.lambda_ = lambda_ @@ -201,6 +208,7 @@ class LambdaPredicate(Predicate): def __str__(self): return self._as_string() + class NotPredicate(Predicate): def __init__(self, predicate): self.predicate = predicate @@ -211,6 +219,7 @@ class NotPredicate(Predicate): def __str__(self): return self.predicate._as_string(True) + class OrPredicate(Predicate): def __init__(self, predicates, description=None): self.predicates = predicates @@ -256,9 +265,11 @@ class OrPredicate(Predicate): _as_predicate = Predicate.as_predicate + def _is_excluded(db, op, spec): return SpecPredicate(db, op, spec)() + def _server_version(engine): """Return a server_version_info tuple.""" @@ -268,24 +279,30 @@ def _server_version(engine): conn.close() return version + def db_spec(*dbs): return OrPredicate( Predicate.as_predicate(db) for db in dbs ) + def open(): return skip_if(BooleanPredicate(False, "mark as execute")) + def closed(): return skip_if(BooleanPredicate(True, "marked as skip")) + @decorator def future(fn, *args, **kw): return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature") + def fails_on(db, reason=None): return fails_if(SpecPredicate(db), reason) + def fails_on_everything_except(*dbs): return succeeds_if( OrPredicate([ @@ -293,9 +310,11 @@ def fails_on_everything_except(*dbs): ]) ) + def skip(db, reason=None): return skip_if(SpecPredicate(db), reason) + def only_on(dbs, reason=None): return only_if( OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 1a1204898..5c587cb2f 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -7,6 +7,7 @@ import sys import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta + class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -29,6 +30,7 @@ class TestBase(object): def assert_(self, val, msg=None): assert val, msg + class TablesTest(TestBase): # 'once', None @@ -208,9 +210,11 @@ class _ORMTest(object): sa.orm.session.Session.close_all() sa.orm.clear_mappers() + class ORMTest(_ORMTest, TestBase): pass + class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None run_setup_classes = 'once' @@ -252,7 +256,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): cls.classes.clear() _ORMTest.teardown_class() - @classmethod def _setup_once_classes(cls): if cls.run_setup_classes == 'once': @@ -275,18 +278,21 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes + class FindFixture(type): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls return type.__init__(cls, classname, bases, dict_) - class _Base(object): __metaclass__ = FindFixture + class Basic(BasicEntity, _Base): pass + class Comparable(ComparableEntity, _Base): pass + cls.Basic = Basic cls.Comparable = Comparable fn() @@ -306,6 +312,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): def setup_mappers(cls): pass + class DeclarativeMappedTest(MappedTest): run_setup_classes = 'once' run_setup_mappers = 'once' @@ -317,17 +324,21 @@ class DeclarativeMappedTest(MappedTest): @classmethod def _with_register_classes(cls, fn): cls_registry = cls.classes + class FindFixtureDeclarative(DeclarativeMeta): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls return DeclarativeMeta.__init__( cls, classname, bases, dict_) + class DeclarativeBasic(object): __table_cls__ = schema.Table + _DeclBase = declarative_base(metadata=cls.metadata, metaclass=FindFixtureDeclarative, cls=DeclarativeBasic) cls.DeclarativeBasic = _DeclBase fn() + if cls.metadata.tables: cls.metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index f5b8b827c..09d51b5fa 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -1,43 +1,59 @@ -"""Classes used in pickling tests, need to be at the module level for unpickling.""" +"""Classes used in pickling tests, need to be at the module level for +unpickling. +""" from . import fixtures + class User(fixtures.ComparableEntity): pass + class Order(fixtures.ComparableEntity): pass + class Dingaling(fixtures.ComparableEntity): pass + class EmailUser(User): pass + class Address(fixtures.ComparableEntity): pass + # TODO: these are kind of arbitrary.... class Child1(fixtures.ComparableEntity): pass + class Child2(fixtures.ComparableEntity): pass + class Parent(fixtures.ComparableEntity): pass + class Screen(object): + def __init__(self, obj, parent=None): self.obj = obj self.parent = parent + class Foo(object): + def __init__(self, moredata): self.data = 'im data' self.stuff = 'im stuff' self.moredata = moredata + __hash__ = object.__hash__ + def __eq__(self, other): return other.data == self.data and \ other.stuff == self.stuff and \ @@ -45,40 +61,53 @@ class Foo(object): class Bar(object): + def __init__(self, x, y): self.x = x self.y = y + __hash__ = object.__hash__ + def __eq__(self, other): return other.__class__ is self.__class__ and \ other.x == self.x and \ other.y == self.y + def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) + class OldSchool: + def __init__(self, x, y): self.x = x self.y = y + def __eq__(self, other): return other.__class__ is self.__class__ and \ other.x == self.x and \ other.y == self.y + class OldSchoolWithoutCompare: + def __init__(self, x, y): self.x = x self.y = y + class BarWithoutCompare(object): + def __init__(self, x, y): self.x = x self.y = y + def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) class NotComparable(object): + def __init__(self, data): self.data = data @@ -93,6 +122,7 @@ class NotComparable(object): class BrokenComparable(object): + def __init__(self, data): self.data = data @@ -104,4 +134,3 @@ class BrokenComparable(object): def __ne__(self, other): raise NotImplementedError - diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 37f7b29f5..c104c4614 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -41,6 +41,7 @@ db_opts = {} options = None _existing_engine = None + def _log(option, opt_str, value, parser): global logging if not logging: @@ -59,34 +60,42 @@ def _list_dbs(*args): print "%20s\t%s" % (macro, file_config.get('db', macro)) sys.exit(0) + def _server_side_cursors(options, opt_str, value, parser): db_opts['server_side_cursors'] = True + def _engine_strategy(options, opt_str, value, parser): if value: db_opts['strategy'] = value pre_configure = [] post_configure = [] + + def pre(fn): pre_configure.append(fn) return fn + + def post(fn): post_configure.append(fn) return fn + @pre def _setup_options(opt, file_config): global options options = opt + @pre def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: - import sys import cdecimal sys.modules['decimal'] = cdecimal + @post def _engine_uri(options, file_config): global db_label, db_url @@ -105,6 +114,7 @@ def _engine_uri(options, file_config): % db_label) db_url = file_config.get('db', db_label) + @post def _require(options, file_config): if not(options.require or @@ -131,12 +141,14 @@ def _require(options, file_config): continue pkg_resources.require(requirement) + @post def _engine_pool(options, file_config): if options.mockpool: from sqlalchemy import pool db_opts['poolclass'] = pool.AssertionPool + @post def _create_testing_engine(options, file_config): from sqlalchemy.testing import engines, config @@ -199,6 +211,7 @@ def _set_table_options(options, file_config): if options.mysql_engine: table_options['mysql_engine'] = options.mysql_engine + @post def _reverse_topological(options, file_config): if options.reversetop: @@ -208,6 +221,7 @@ def _reverse_topological(options, file_config): topological.set = unitofwork.set = session.set = mapper.set = \ dependency.set = RandomSet + @post def _requirements(options, file_config): from sqlalchemy.testing import config @@ -230,6 +244,7 @@ def _post_setup_options(opt, file_config): from sqlalchemy.testing import config config.options = options + @post def _setup_profiling(options, file_config): from sqlalchemy.testing import profiling diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index a22e83cbc..ae9d176b7 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -22,6 +22,7 @@ from ..util.compat import jython, pypy, win32 _current_test = None + def profiled(target=None, **target_opts): """Function profiling. @@ -69,13 +70,13 @@ def profiled(target=None, **target_opts): else: stats.print_stats() - print_callers = target_opts.get('print_callers', - profile_config['print_callers']) + print_callers = target_opts.get( + 'print_callers', profile_config['print_callers']) if print_callers: stats.print_callers() - print_callees = target_opts.get('print_callees', - profile_config['print_callees']) + print_callees = target_opts.get( + 'print_callees', profile_config['print_callees']) if print_callees: stats.print_callees() @@ -92,10 +93,14 @@ class ProfileStatsFile(object): """ def __init__(self, filename): - self.write = config.options is not None and config.options.write_profiles + self.write = ( + config.options is not None and + config.options.write_profiles + ) self.fname = os.path.abspath(filename) self.short_fname = os.path.split(self.fname)[-1] - self.data = collections.defaultdict(lambda: collections.defaultdict(dict)) + self.data = collections.defaultdict( + lambda: collections.defaultdict(dict)) self._read() if self.write: # rewrite for the case where features changed, @@ -124,7 +129,10 @@ class ProfileStatsFile(object): def has_stats(self): test_key = _current_test - return test_key in self.data and self.platform_key in self.data[test_key] + return ( + test_key in self.data and + self.platform_key in self.data[test_key] + ) def result(self, callcount): test_key = _current_test @@ -153,7 +161,6 @@ class ProfileStatsFile(object): per_platform['current_count'] += 1 return result - def _header(self): return \ "# %s\n"\ @@ -165,8 +172,8 @@ class ProfileStatsFile(object): "# assertions are raised if the counts do not match.\n"\ "# \n"\ "# To add a new callcount test, apply the function_call_count \n"\ - "# decorator and re-run the tests using the --write-profiles option - \n"\ - "# this file will be rewritten including the new count.\n"\ + "# decorator and re-run the tests using the --write-profiles \n"\ + "# option - this file will be rewritten including the new count.\n"\ "# \n"\ "" % (self.fname) @@ -183,7 +190,8 @@ class ProfileStatsFile(object): test_key, platform_key, counts = line.split() per_fn = self.data[test_key] per_platform = per_fn[platform_key] - per_platform['counts'] = [int(count) for count in counts.split(",")] + c = [int(count) for count in counts.split(",")] + per_platform['counts'] = c per_platform['lineno'] = lineno + 1 per_platform['current_count'] = 0 profile_f.close() @@ -198,16 +206,13 @@ class ProfileStatsFile(object): profile_f.write("\n# TEST: %s\n\n" % test_key) for platform_key in sorted(per_fn): per_platform = per_fn[platform_key] - profile_f.write( - "%s %s %s\n" % ( - test_key, - platform_key, ",".join(str(count) for count in per_platform['counts']) - ) - ) + c = ",".join(str(count) for count in per_platform['counts']) + profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) profile_f.close() from sqlalchemy.util.compat import update_wrapper + def function_call_count(variance=0.05): """Assert a target for a test case's function call count. @@ -222,7 +227,6 @@ def function_call_count(variance=0.05): def decorate(fn): def wrap(*args, **kw): - if cProfile is None: raise SkipTest("cProfile is not installed") @@ -237,7 +241,6 @@ def function_call_count(variance=0.05): gc_collect() - timespent, load_stats, fn_result = _profile( fn, *args, **kw ) @@ -263,8 +266,9 @@ def function_call_count(variance=0.05): if abs(callcount - expected_count) > deviance: raise AssertionError( "Adjusted function call count %s not within %s%% " - "of expected %s. (Delete line %d of file %s to regenerate " - "this callcount, when tests are run with --write-profiles.)" + "of expected %s. (Delete line %d of file %s to " + "regenerate this callcount, when tests are run " + "with --write-profiles.)" % ( callcount, (variance * 100), expected_count, line_no, @@ -288,4 +292,3 @@ def _profile(fn, *args, **kw): ended = time.time() return ended - began, load_stats, locals()['result'] - diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index d58538db9..68659a855 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -10,6 +10,7 @@ to provide specific inclusion/exlusions. from . import exclusions + class Requirements(object): def __init__(self, db, config): self.db = db @@ -178,7 +179,6 @@ class SuiteRequirements(Requirements): """ return exclusions.open() - @property def datetime(self): """target dialect supports representation of Python @@ -237,8 +237,10 @@ class SuiteRequirements(Requirements): @property def empty_strings_varchar(self): - """target database can persist/return an empty string with a varchar.""" + """target database can persist/return an empty string with a + varchar. + """ return exclusions.open() @property @@ -248,7 +250,6 @@ class SuiteRequirements(Requirements): return exclusions.open() - @property def update_from(self): """Target must support UPDATE..FROM syntax""" diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index 1a4ba5212..6ec73d7c8 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -28,5 +28,6 @@ from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy import nose + def main(): nose.main(addplugins=[NoseSQLAlchemy()]) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 5dfdc0e07..ad233ec22 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -7,6 +7,7 @@ __all__ = 'Table', 'Column', table_options = {} + def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" @@ -76,10 +77,10 @@ def Column(*args, **kw): event.listen(col, 'after_parent_attach', add_seq, propagate=True) return col + def _truncate_name(dialect, name): if len(name) > dialect.max_identifier_length: return name[0:max(dialect.max_identifier_length - 6, 0)] + \ "_" + hex(hash(name) % 64)[2:] else: return name - diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 466429aa5..c5b162413 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -25,7 +25,6 @@ class TableDDLTest(fixtures.TestBase): (1, 'some data') ) - @requirements.create_table @util.provide_metadata def test_create_table(self): @@ -35,7 +34,6 @@ class TableDDLTest(fixtures.TestBase): ) self._simple_roundtrip() - @requirements.drop_table @util.provide_metadata def test_drop_table(self): @@ -48,4 +46,4 @@ class TableDDLTest(fixtures.TestBase): ) -__all__ = ('TableDDLTest', )
\ No newline at end of file +__all__ = ('TableDDLTest', ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 3cd7d39bc..b2b2a0aa8 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -8,6 +8,7 @@ from sqlalchemy import Integer, String, select, util from ..schema import Table, Column + class LastrowidTest(fixtures.TablesTest): run_deletes = 'each' @@ -88,7 +89,6 @@ class InsertBehaviorTest(fixtures.TablesTest): else: engine = config.db - r = engine.execute( self.tables.autoinc_pk.insert(), data="some data" @@ -107,6 +107,7 @@ class InsertBehaviorTest(fixtures.TablesTest): assert r.is_insert assert not r.returns_rows + class ReturningTest(fixtures.TablesTest): run_deletes = 'each' __requires__ = 'returning', 'autoincrement_insert' @@ -162,5 +163,3 @@ class ReturningTest(fixtures.TablesTest): __all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest') - - diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index a7c814db7..b9894347a 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -18,6 +18,7 @@ from sqlalchemy import event metadata, users = None, None + class HasTableTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): @@ -31,6 +32,7 @@ class HasTableTest(fixtures.TablesTest): assert config.db.dialect.has_table(conn, "test_table") assert not config.db.dialect.has_table(conn, "nonexistent_table") + class HasSequenceTest(fixtures.TestBase): __requires__ = 'sequences', @@ -425,4 +427,4 @@ class ComponentReflectionTest(fixtures.TablesTest): self._test_get_table_oid('users', schema='test_schema') -__all__ = ('ComponentReflectionTest', 'HasSequenceTest', 'HasTableTest')
\ No newline at end of file +__all__ = ('ComponentReflectionTest', 'HasSequenceTest', 'HasTableTest') diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 74cb52c6e..8d0500d71 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -8,6 +8,7 @@ from sqlalchemy import Date, DateTime, Time, MetaData, String from ..schema import Table, Column import datetime + class _UnicodeFixture(object): __requires__ = 'unicode_data', @@ -70,7 +71,6 @@ class _UnicodeFixture(object): for row in rows: assert isinstance(row[0], unicode) - def _test_empty_strings(self): unicode_table = self.tables.unicode_table @@ -83,16 +83,17 @@ class _UnicodeFixture(object): ).first() eq_(row, (u'',)) + class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): __requires__ = 'unicode_data', datatype = Unicode(255) - @requirements.empty_strings_varchar def test_empty_strings_varchar(self): self._test_empty_strings() + class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): __requires__ = 'unicode_data', 'text_type' @@ -114,6 +115,7 @@ class StringTest(fixtures.TestBase): foo.create(config.db) foo.drop(config.db) + class _DateFixture(object): compare = None @@ -165,37 +167,44 @@ class DateTimeTest(_DateFixture, fixtures.TablesTest): datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18) + class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): __requires__ = 'datetime_microseconds', datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + class TimeTest(_DateFixture, fixtures.TablesTest): __requires__ = 'time', datatype = Time data = datetime.time(12, 57, 18) + class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): __requires__ = 'time_microseconds', datatype = Time data = datetime.time(12, 57, 18, 396) + class DateTest(_DateFixture, fixtures.TablesTest): __requires__ = 'date', datatype = Date data = datetime.date(2012, 10, 15) + class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): __requires__ = 'date', datatype = Date data = datetime.datetime(2012, 10, 15, 12, 57, 18) compare = datetime.date(2012, 10, 15) + class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): __requires__ = 'datetime_historic', datatype = DateTime data = datetime.datetime(1850, 11, 10, 11, 52, 35) + class DateHistoricTest(_DateFixture, fixtures.TablesTest): __requires__ = 'date_historic', datatype = Date @@ -207,6 +216,3 @@ __all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', 'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest', 'DateHistoricTest', 'StringTest') - - - diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index e73b05485..a3456ac2a 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -1,9 +1,7 @@ from .. import fixtures, config -from ..config import requirements from ..assertions import eq_ -from .. import engines -from sqlalchemy import Integer, String, select +from sqlalchemy import Integer, String from ..schema import Table, Column @@ -61,4 +59,4 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): ] ) -__all__ = ('SimpleUpdateDeleteTest', )
\ No newline at end of file +__all__ = ('SimpleUpdateDeleteTest', ) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 41b0a30b3..2592c341e 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -26,9 +26,11 @@ elif pypy: else: # assume CPython - straight gc.collect, lazy_gc() is a pass gc_collect = gc.collect + def lazy_gc(): pass + def picklers(): picklers = set() # Py2K @@ -56,6 +58,7 @@ def round_decimal(value, prec): ).to_integral(decimal.ROUND_FLOOR) / \ pow(10, prec) + class RandomSet(set): def __iter__(self): l = list(set.__iter__(self)) @@ -80,6 +83,7 @@ class RandomSet(set): def copy(self): return RandomSet(self) + def conforms_partial_ordering(tuples, sorted_elements): """True if the given sorting conforms to the given partial ordering.""" @@ -93,6 +97,7 @@ def conforms_partial_ordering(tuples, sorted_elements): else: return True + def all_partial_orderings(tuples, elements): edges = defaultdict(set) for parent, child in tuples: @@ -131,7 +136,6 @@ def function_named(fn, name): return fn - def run_as_contextmanager(ctx, fn, *arg, **kw): """Run the given function under the given contextmanager, simulating the behavior of 'with' to support older @@ -152,6 +156,7 @@ def run_as_contextmanager(ctx, fn, *arg, **kw): else: return raise_ + def rowset(results): """Converts the results of sql execution into a plain set of column tuples. @@ -182,6 +187,7 @@ def provide_metadata(fn, *args, **kw): metadata.drop_all() self.metadata = prev_meta + class adict(dict): """Dict keys available as attributes. Shadows.""" def __getattribute__(self, key): @@ -192,5 +198,3 @@ class adict(dict): def get_all(self, *keys): return tuple([self[key] for key in keys]) - - diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 7afcc63c5..41f3dbfed 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -4,6 +4,7 @@ import warnings from .. import exc as sa_exc from .. import util + def testing_warn(msg, stacklevel=3): """Replaces sqlalchemy.util.warn during tests.""" @@ -14,6 +15,7 @@ def testing_warn(msg, stacklevel=3): else: warnings.warn_explicit(msg, filename, lineno) + def resetwarnings(): """Reset warning behavior to testing defaults.""" @@ -24,6 +26,7 @@ def resetwarnings(): warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) warnings.filterwarnings('error', category=sa_exc.SAWarning) + def assert_warnings(fn, warnings): """Assert that each of the given warnings are emitted by fn.""" @@ -31,6 +34,7 @@ def assert_warnings(fn, warnings): canary = [] orig_warn = util.warn + def capture_warnings(*args, **kw): orig_warn(*args, **kw) popwarn = warnings.pop(0) |