diff options
Diffstat (limited to 'lib/sqlalchemy/testing/engines.py')
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 60 |
1 files changed, 30 insertions, 30 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d17e30edf..074e3b338 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -16,7 +16,6 @@ import warnings class ConnectionKiller(object): - def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -39,8 +38,8 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " - "rollback/close connection: %s" % e) + "testing_reaper couldn't " "rollback/close connection: %s" % e + ) def rollback_all(self): for rec in list(self.proxy_refs): @@ -97,18 +96,19 @@ class ConnectionKiller(object): if rec.is_valid: assert False + testing_reaper = ConnectionKiller() def drop_all_tables(metadata, bind): testing_reaper.close_all() - if hasattr(bind, 'close'): + if hasattr(bind, "close"): bind.close() if not config.db.dialect.supports_alter: from . import assertions - with assertions.expect_warnings( - "Can't sort tables", assert_=False): + + with assertions.expect_warnings("Can't sort tables", assert_=False): metadata.drop_all(bind) else: metadata.drop_all(bind) @@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw): def all_dialects(exclude=None): import sqlalchemy.databases as d + for name in d.__all__: # TEMPORARY if exclude and name in exclude: continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__( - 'sqlalchemy.databases.%s' % name).databases, name) + mod = getattr( + __import__("sqlalchemy.databases.%s" % name).databases, name + ) yield mod.dialect() class ReconnectFixture(object): - def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -191,8 +192,8 @@ class ReconnectFixture(object): fn() except Exception as e: warnings.warn( - "ReconnectFixture couldn't " - "close connection: %s" % e) + "ReconnectFixture couldn't " "close connection: %s" % e + ) def shutdown(self, stop=False): # TODO: this doesn't cover all cases @@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None): dbapi = config.db.dialect.dbapi if not options: options = {} - options['module'] = ReconnectFixture(dbapi) + options["module"] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose @@ -238,7 +239,7 @@ def testing_engine(url=None, options=None): if not options: use_reaper = True else: - use_reaper = options.pop('use_reaper', True) + use_reaper = options.pop("use_reaper", True) url = url or config.db.url @@ -253,15 +254,15 @@ def testing_engine(url=None, options=None): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 if use_reaper: - event.listen(engine.pool, 'connect', testing_reaper.connect) - event.listen(engine.pool, 'checkout', testing_reaper.checkout) - event.listen(engine.pool, 'invalidate', testing_reaper.invalidate) + event.listen(engine.pool, "connect", testing_reaper.connect) + event.listen(engine.pool, "checkout", testing_reaper.checkout) + event.listen(engine.pool, "invalidate", testing_reaper.invalidate) testing_reaper.add_engine(engine) return engine @@ -290,19 +291,17 @@ def mock_engine(dialect_name=None): buffer.append(sql) def assert_sql(stmts): - recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] assert recv == stmts, recv def print_sql(): d = engine.dialect - return "\n".join( - str(s.compile(dialect=d)) - for s in engine.mock - ) - - engine = create_engine(dialect_name + '://', - strategy='mock', executor=executor) - assert not hasattr(engine, 'mock') + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_engine( + dialect_name + "://", strategy="mock", executor=executor + ) + assert not hasattr(engine, "mock") engine.mock = buffer engine.assert_sql = assert_sql engine.print_sql = print_sql @@ -358,14 +357,15 @@ class DBAPIProxyConnection(object): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, - cursor_cls=DBAPIProxyCursor): +def proxying_engine( + conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor +): """Produce an engine that provides proxy hooks for common methods. """ + def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator': mock_conn}) - + return testing_engine(options={"creator": mock_conn}) |