Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r-- | test/engine/test_pool.py | 103
1 files changed, 93 insertions, 10 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index fc6f3dcea..bbab0a7c3 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -7,7 +7,7 @@ from sqlalchemy.testing.util import gc_collect, lazy_gc
 from sqlalchemy.testing import eq_, assert_raises, is_not_
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing import fixtures
-
+import random
 from sqlalchemy.testing.mock import Mock, call

 join_timeout = 10
@@ -1069,7 +1069,8 @@ class QueuePoolTest(PoolTestBase):
                 # inside the queue, before we invalidate the other
                 # two conns
                 time.sleep(.2)
-                p2 = p._replace()
+                p._invalidate(c2)
+                c2.invalidate()

                 for t in threads:
                     t.join(join_timeout)
@@ -1079,19 +1080,18 @@ class QueuePoolTest(PoolTestBase):
     @testing.requires.threading_with_mock
     def test_notify_waiters(self):
         dbapi = MockDBAPI()
+
         canary = []
-        def creator1():
+        def creator():
             canary.append(1)
             return dbapi.connect()
-        def creator2():
-            canary.append(2)
-            return dbapi.connect()
-        p1 = pool.QueuePool(creator=creator1,
+        p1 = pool.QueuePool(creator=creator,
                            pool_size=1, timeout=None,
                            max_overflow=0)
-        p2 = pool.NullPool(creator=creator2)
+        #p2 = pool.NullPool(creator=creator2)
         def waiter(p):
             conn = p.connect()
+            canary.append(2)
             time.sleep(.5)
             conn.close()

@@ -1104,12 +1104,14 @@ class QueuePoolTest(PoolTestBase):
             threads.append(t)
         time.sleep(.5)
         eq_(canary, [1])
-        p1._pool.abort(p2)
+
+        c1.invalidate()
+        p1._invalidate(c1)

         for t in threads:
             t.join(join_timeout)

-        eq_(canary, [1, 2, 2, 2, 2, 2])
+        eq_(canary, [1, 1, 2, 2, 2, 2, 2])

     def test_dispose_closes_pooled(self):
         dbapi = MockDBAPI()
@@ -1251,6 +1253,21 @@ class QueuePoolTest(PoolTestBase):
         c3 = p.connect()
         assert id(c3.connection) != c_id

+    def test_recycle_on_invalidate(self):
+        p = self._queuepool_fixture(pool_size=1,
+                           max_overflow=0)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close()
+        c2 = p.connect()
+        assert id(c2.connection) == c_id
+
+        p._invalidate(c2)
+        c2.close()
+        time.sleep(.5)
+        c3 = p.connect()
+        assert id(c3.connection) != c_id
+
     def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
         # p is QueuePool with size=1, max_overflow=2,
         # and one connection in the pool that will need to
@@ -1290,6 +1307,72 @@ class QueuePoolTest(PoolTestBase):
             time.sleep(1)
         self._assert_cleanup_on_pooled_reconnect(dbapi, p)

+    def test_recycle_pool_no_race(self):
+        def slow_close():
+            slow_closing_connection._slow_close()
+            time.sleep(.5)
+
+        slow_closing_connection = Mock()
+        slow_closing_connection.connect.return_value.close = slow_close
+
+        class Error(Exception):
+            pass
+
+        dialect = Mock()
+        dialect.is_disconnect = lambda *arg, **kw: True
+        dialect.dbapi.Error = Error
+
+        pools = []
+        class TrackQueuePool(pool.QueuePool):
+            def __init__(self, *arg, **kw):
+                pools.append(self)
+                super(TrackQueuePool, self).__init__(*arg, **kw)
+
+        def creator():
+            return slow_closing_connection.connect()
+        p1 = TrackQueuePool(creator=creator, pool_size=20)
+
+        from sqlalchemy import create_engine
+        eng = create_engine("postgresql://", pool=p1, _initialize=False)
+        eng.dialect = dialect
+
+        # 15 total connections
+        conns = [eng.connect() for i in range(15)]
+
+        # return 8 back to the pool
+        for conn in conns[3:10]:
+            conn.close()
+
+        def attempt(conn):
+            time.sleep(random.random())
+            try:
+                conn._handle_dbapi_exception(Error(), "statement", {}, Mock(), Mock())
+            except tsa.exc.DBAPIError:
+                pass
+
+        # run an error + invalidate operation on the remaining 7 open connections
+        threads = []
+        for conn in conns:
+            t = threading.Thread(target=attempt, args=(conn, ))
+            t.start()
+            threads.append(t)
+
+        for t in threads:
+            t.join()
+
+        # return all 15 connections to the pool
+        for conn in conns:
+            conn.close()
+
+        # re-open 15 total connections
+        conns = [eng.connect() for i in range(15)]
+
+        # 15 connections have been fully closed due to invalidate
+        assert slow_closing_connection._slow_close.call_count == 15
+
+        # 15 initial connections + 15 reconnections
+        assert slow_closing_connection.connect.call_count == 30
+        assert len(pools) <= 2, len(pools)
     def test_invalidate(self):
         p = self._queuepool_fixture(pool_size=1, max_overflow=0)
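For orientation, the change drops the earlier pool-swap mechanics (p2 = p._replace(), p1._pool.abort(p2)) in favor of per-connection invalidation plus a pool-wide invalidation timestamp, which the new test_recycle_on_invalidate and test_recycle_pool_no_race exercise. The snippet below is a minimal illustrative sketch of that flow, not part of the commit: it mirrors test_recycle_on_invalidate but uses a real in-memory SQLite connection instead of the suite's MockDBAPI, and it only touches calls that appear in the diff (QueuePool, connect(), close(), and the private _invalidate() helper this change relies on).

    import time
    import sqlite3

    from sqlalchemy import pool

    def creator():
        # stand-in DBAPI connect; the tests above use MockDBAPI() instead
        return sqlite3.connect(":memory:")

    p = pool.QueuePool(creator=creator, pool_size=1, max_overflow=0)

    c1 = p.connect()
    raw = c1.connection        # raw DBAPI connection (attribute name used in this era)
    c1.close()                 # checked back into the pool

    c2 = p.connect()
    assert c2.connection is raw    # the pooled connection is reused as-is

    p._invalidate(c2)          # private helper exercised by the tests: stamp the pool
                               # so connections created before now count as stale
    c2.close()
    time.sleep(.5)

    c3 = p.connect()           # stale connection is discarded, creator() runs again
    assert c3.connection is not raw

Since _invalidate() is an underscore-prefixed internal, application code would normally stop at invalidate() on a checked-out connection or dispose of the pool; the tests reach for the private helper because the pool-wide freshness stamp is exactly the behavior this change introduces.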