summaryrefslogtreecommitdiff
path: root/test/engine/test_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r--test/engine/test_pool.py103
1 files changed, 93 insertions, 10 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index fc6f3dcea..bbab0a7c3 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -7,7 +7,7 @@ from sqlalchemy.testing.util import gc_collect, lazy_gc
from sqlalchemy.testing import eq_, assert_raises, is_not_
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing import fixtures
-
+import random
from sqlalchemy.testing.mock import Mock, call
join_timeout = 10
@@ -1069,7 +1069,8 @@ class QueuePoolTest(PoolTestBase):
# inside the queue, before we invalidate the other
# two conns
time.sleep(.2)
- p2 = p._replace()
+ p._invalidate(c2)
+ c2.invalidate()
for t in threads:
t.join(join_timeout)
@@ -1079,19 +1080,18 @@ class QueuePoolTest(PoolTestBase):
@testing.requires.threading_with_mock
def test_notify_waiters(self):
dbapi = MockDBAPI()
+
canary = []
- def creator1():
+ def creator():
canary.append(1)
return dbapi.connect()
- def creator2():
- canary.append(2)
- return dbapi.connect()
- p1 = pool.QueuePool(creator=creator1,
+ p1 = pool.QueuePool(creator=creator,
pool_size=1, timeout=None,
max_overflow=0)
- p2 = pool.NullPool(creator=creator2)
+ #p2 = pool.NullPool(creator=creator2)
def waiter(p):
conn = p.connect()
+ canary.append(2)
time.sleep(.5)
conn.close()
@@ -1104,12 +1104,14 @@ class QueuePoolTest(PoolTestBase):
threads.append(t)
time.sleep(.5)
eq_(canary, [1])
- p1._pool.abort(p2)
+
+ c1.invalidate()
+ p1._invalidate(c1)
for t in threads:
t.join(join_timeout)
- eq_(canary, [1, 2, 2, 2, 2, 2])
+ eq_(canary, [1, 1, 2, 2, 2, 2, 2])
def test_dispose_closes_pooled(self):
dbapi = MockDBAPI()
@@ -1251,6 +1253,21 @@ class QueuePoolTest(PoolTestBase):
c3 = p.connect()
assert id(c3.connection) != c_id
+ def test_recycle_on_invalidate(self):
+ p = self._queuepool_fixture(pool_size=1,
+ max_overflow=0)
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close()
+ c2 = p.connect()
+ assert id(c2.connection) == c_id
+
+ p._invalidate(c2)
+ c2.close()
+ time.sleep(.5)
+ c3 = p.connect()
+ assert id(c3.connection) != c_id
+
def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
# p is QueuePool with size=1, max_overflow=2,
# and one connection in the pool that will need to
@@ -1290,6 +1307,72 @@ class QueuePoolTest(PoolTestBase):
time.sleep(1)
self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+ def test_recycle_pool_no_race(self):
+ def slow_close():
+ slow_closing_connection._slow_close()
+ time.sleep(.5)
+
+ slow_closing_connection = Mock()
+ slow_closing_connection.connect.return_value.close = slow_close
+
+ class Error(Exception):
+ pass
+
+ dialect = Mock()
+ dialect.is_disconnect = lambda *arg, **kw: True
+ dialect.dbapi.Error = Error
+
+ pools = []
+ class TrackQueuePool(pool.QueuePool):
+ def __init__(self, *arg, **kw):
+ pools.append(self)
+ super(TrackQueuePool, self).__init__(*arg, **kw)
+
+ def creator():
+ return slow_closing_connection.connect()
+ p1 = TrackQueuePool(creator=creator, pool_size=20)
+
+ from sqlalchemy import create_engine
+ eng = create_engine("postgresql://", pool=p1, _initialize=False)
+ eng.dialect = dialect
+
+ # 15 total connections
+ conns = [eng.connect() for i in range(15)]
+
+ # return 8 back to the pool
+ for conn in conns[3:10]:
+ conn.close()
+
+ def attempt(conn):
+ time.sleep(random.random())
+ try:
+ conn._handle_dbapi_exception(Error(), "statement", {}, Mock(), Mock())
+ except tsa.exc.DBAPIError:
+ pass
+
+ # run an error + invalidate operation on the remaining 7 open connections
+ threads = []
+ for conn in conns:
+ t = threading.Thread(target=attempt, args=(conn, ))
+ t.start()
+ threads.append(t)
+
+ for t in threads:
+ t.join()
+
+ # return all 15 connections to the pool
+ for conn in conns:
+ conn.close()
+
+ # re-open 15 total connections
+ conns = [eng.connect() for i in range(15)]
+
+ # 15 connections have been fully closed due to invalidate
+ assert slow_closing_connection._slow_close.call_count == 15
+
+ # 15 initial connections + 15 reconnections
+ assert slow_closing_connection.connect.call_count == 30
+ assert len(pools) <= 2, len(pools)
def test_invalidate(self):
p = self._queuepool_fixture(pool_size=1, max_overflow=0)