diff options
Diffstat (limited to 'test/lib/engines.py')
-rw-r--r-- | test/lib/engines.py | 56 |
1 files changed, 39 insertions, 17 deletions
diff --git a/test/lib/engines.py b/test/lib/engines.py index 4794a5fab..3a5132b8b 100644 --- a/test/lib/engines.py +++ b/test/lib/engines.py @@ -3,37 +3,55 @@ from collections import deque from test.bootstrap import config from test.lib.util import decorator from sqlalchemy.util import callable -from sqlalchemy import event +from sqlalchemy import event, pool +from sqlalchemy.engine import base as engine_base import re import warnings class ConnectionKiller(object): def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = weakref.WeakKeyDictionary() + self.conns = set() + + def add_engine(self, engine): + self.testing_engines[engine] = True def checkout(self, dbapi_con, con_record, con_proxy): self.proxy_refs[con_proxy] = True + self.conns.add(dbapi_con) + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn( + "testing_reaper couldn't " + "rollback/close connection: %s" % e) - def _apply_all(self, methods): - # must copy keys atomically + def rollback_all(self): for rec in self.proxy_refs.keys(): if rec is not None and rec.is_valid: - try: - for name in methods: - if callable(name): - name(rec) - else: - getattr(rec, name)() - except (SystemExit, KeyboardInterrupt): - raise - except Exception, e: - warnings.warn("testing_reaper couldn't close connection: %s" % e) - - def rollback_all(self): - self._apply_all(('rollback',)) + self._safe(rec.rollback) def close_all(self): - self._apply_all(('rollback', 'close')) + for rec in self.proxy_refs.keys(): + if rec is not None: + self._safe(rec._close) + + def _after_test_ctx(self): + for conn in self.conns: + self._safe(conn.rollback) + + def _stop_test_ctx(self): + self.close_all() + for conn in self.conns: + self._safe(conn.close) + self.conns = set() + for rec in self.testing_engines.keys(): + rec.dispose() def assert_all_closed(self): for rec in self.proxy_refs: @@ -134,6 +152,9 @@ def testing_engine(url=None, options=None): options = options or config.db_opts engine = create_engine(url, **options) + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 event.listen(engine, 'after_execute', asserter.execute) event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) event.listen(engine.pool, 'checkout', testing_reaper.checkout) @@ -141,6 +162,7 @@ def testing_engine(url=None, options=None): # may want to call this, results # in first-connect initializers #engine.connect() + testing_reaper.add_engine(engine) return engine |