summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/engines.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/engines.py')
-rw-r--r--lib/sqlalchemy/testing/engines.py60
1 files changed, 30 insertions, 30 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index d17e30edf..074e3b338 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -16,7 +16,6 @@ import warnings
class ConnectionKiller(object):
-
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
@@ -39,8 +38,8 @@ class ConnectionKiller(object):
fn()
except Exception as e:
warnings.warn(
- "testing_reaper couldn't "
- "rollback/close connection: %s" % e)
+ "testing_reaper couldn't " "rollback/close connection: %s" % e
+ )
def rollback_all(self):
for rec in list(self.proxy_refs):
@@ -97,18 +96,19 @@ class ConnectionKiller(object):
if rec.is_valid:
assert False
+
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
- if hasattr(bind, 'close'):
+ if hasattr(bind, "close"):
bind.close()
if not config.db.dialect.supports_alter:
from . import assertions
- with assertions.expect_warnings(
- "Can't sort tables", assert_=False):
+
+ with assertions.expect_warnings("Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw):
def all_dialects(exclude=None):
import sqlalchemy.databases as d
+
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__(
- 'sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(
+ __import__("sqlalchemy.databases.%s" % name).databases, name
+ )
yield mod.dialect()
class ReconnectFixture(object):
-
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
@@ -191,8 +192,8 @@ class ReconnectFixture(object):
fn()
except Exception as e:
warnings.warn(
- "ReconnectFixture couldn't "
- "close connection: %s" % e)
+ "ReconnectFixture couldn't " "close connection: %s" % e
+ )
def shutdown(self, stop=False):
# TODO: this doesn't cover all cases
@@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None):
dbapi = config.db.dialect.dbapi
if not options:
options = {}
- options['module'] = ReconnectFixture(dbapi)
+ options["module"] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
@@ -238,7 +239,7 @@ def testing_engine(url=None, options=None):
if not options:
use_reaper = True
else:
- use_reaper = options.pop('use_reaper', True)
+ use_reaper = options.pop("use_reaper", True)
url = url or config.db.url
@@ -253,15 +254,15 @@ def testing_engine(url=None, options=None):
default_opt.update(options)
engine = create_engine(url, **options)
- engine._has_events = True # enable event blocks, helps with profiling
+ engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
- event.listen(engine.pool, 'connect', testing_reaper.connect)
- event.listen(engine.pool, 'checkout', testing_reaper.checkout)
- event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
+ event.listen(engine.pool, "connect", testing_reaper.connect)
+ event.listen(engine.pool, "checkout", testing_reaper.checkout)
+ event.listen(engine.pool, "invalidate", testing_reaper.invalidate)
testing_reaper.add_engine(engine)
return engine
@@ -290,19 +291,17 @@ def mock_engine(dialect_name=None):
buffer.append(sql)
def assert_sql(stmts):
- recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
- return "\n".join(
- str(s.compile(dialect=d))
- for s in engine.mock
- )
-
- engine = create_engine(dialect_name + '://',
- strategy='mock', executor=executor)
- assert not hasattr(engine, 'mock')
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_engine(
+ dialect_name + "://", strategy="mock", executor=executor
+ )
+ assert not hasattr(engine, "mock")
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
@@ -358,14 +357,15 @@ class DBAPIProxyConnection(object):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection,
- cursor_cls=DBAPIProxyCursor):
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
"""Produce an engine that provides proxy hooks for
common methods.
"""
+
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator': mock_conn})
-
+ return testing_engine(options={"creator": mock_conn})