diff options
Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r-- | test/engine/test_pool.py | 149 |
1 files changed, 85 insertions, 64 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 583978465..981df6dd0 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -4,37 +4,21 @@ from sqlalchemy import pool, select, event import sqlalchemy as tsa from sqlalchemy import testing from sqlalchemy.testing.util import gc_collect, lazy_gc -from sqlalchemy.testing import eq_, assert_raises +from sqlalchemy.testing import eq_, assert_raises, is_not_ from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing import fixtures -mcid = 1 -class MockDBAPI(object): - throw_error = False - def connect(self, *args, **kwargs): - if self.throw_error: - raise Exception("couldnt connect !") - delay = kwargs.pop('delay', 0) - if delay: - time.sleep(delay) - return MockConnection() -class MockConnection(object): - closed = False - def __init__(self): - global mcid - self.id = mcid - mcid += 1 - def close(self): - self.closed = True - def rollback(self): - pass - def cursor(self): - return MockCursor() -class MockCursor(object): - def execute(self, *args, **kw): - pass - def close(self): - pass +from sqlalchemy.testing.mock import Mock, call + +def MockDBAPI(): + def cursor(): + while True: + yield Mock() + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) class PoolTestBase(fixtures.TestBase): def setup(self): @@ -71,11 +55,9 @@ class PoolTest(PoolTestBase): assert c4 is not c5 def test_manager_with_key(self): - class NoKws(object): - def connect(self, arg): - return MockConnection() - manager = pool.manage(NoKws(), use_threadlocal=True) + dbapi = MockDBAPI() + manager = pool.manage(dbapi, use_threadlocal=True) c1 = manager.connect('foo.db', sa_pool_key="a") c2 = manager.connect('foo.db', sa_pool_key="b") @@ -83,9 +65,14 @@ class PoolTest(PoolTestBase): assert c1.cursor() is not None assert c1 is not c2 - assert c1 is c3 - + assert c1 is c3 + eq_(dbapi.connect.mock_calls, + [ + call("foo.db"), + call("foo.db"), + ] + ) def test_bad_args(self): @@ -127,7 +114,7 @@ class PoolTest(PoolTestBase): p = cls(creator=mock_dbapi.connect) conn = p.connect() conn.close() - mock_dbapi.throw_error = True + mock_dbapi.connect.side_effect = Exception("error!") p.dispose() p.recreate() @@ -211,9 +198,9 @@ class PoolTest(PoolTestBase): self.assert_('foo2' in c.info) c2 = p.connect() - self.assert_(c.connection is not c2.connection) - self.assert_(not c2.info) - self.assert_('foo2' in c.info) + is_not_(c.connection, c2.connection) + assert not c2.info + assert 'foo2' in c.info class PoolDialectTest(PoolTestBase): @@ -945,19 +932,24 @@ class QueuePoolTest(PoolTestBase): def test_dispose_closes_pooled(self): dbapi = MockDBAPI() - def creator(): - return dbapi.connect() - p = pool.QueuePool(creator=creator, + p = pool.QueuePool(creator=dbapi.connect, pool_size=2, timeout=None, max_overflow=0) c1 = p.connect() c2 = p.connect() - conns = [c1.connection, c2.connection] + c1_con = c1.connection + c2_con = c2.connection + c1.close() - eq_([c.closed for c in conns], [False, False]) + + eq_(c1_con.close.call_count, 0) + eq_(c2_con.close.call_count, 0) + p.dispose() - eq_([c.closed for c in conns], [True, False]) + + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # currently, if a ConnectionFairy is closed # after the pool has been disposed, there's no @@ -965,11 +957,12 @@ class QueuePoolTest(PoolTestBase): # immediately - it just gets returned to the # pool normally... c2.close() - eq_([c.closed for c in conns], [True, False]) + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # ...and that's the one we'll get back next. c3 = p.connect() - assert c3.connection is conns[1] + assert c3.connection is c2_con def test_no_overflow(self): self._test_overflow(40, 0) @@ -1010,12 +1003,21 @@ class QueuePoolTest(PoolTestBase): return c for j in range(5): + # open 4 conns at a time. each time this + # will yield two pooled connections + two + # overflow connections. conns = [_conn() for i in range(4)] for c in conns: c.close() - still_opened = len([c for c in strong_refs if not c.closed]) - eq_(still_opened, 2) + # doing that for a total of 5 times yields + # ten overflow connections closed plus the + # two pooled connections unclosed. + + eq_( + set([c.close.call_count for c in strong_refs]), + set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]) + ) @testing.requires.predictable_gc def test_weakref_kaboom(self): @@ -1108,18 +1110,30 @@ class QueuePoolTest(PoolTestBase): dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) c1 = p.connect() c1.detach() - c_id = c1.connection.id - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - c2.invalidate() - c2 = None c2 = p.connect() - assert c2.connection.id != c1.connection.id - con = c1.connection - assert not con.closed + eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")]) + + c1_con = c1.connection + assert c1_con is not None + eq_(c1_con.close.call_count, 0) c1.close() - assert con.closed + eq_(c1_con.close.call_count, 1) + + def test_detach_via_invalidate(self): + dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) + + c1 = p.connect() + c1_con = c1.connection + c1.invalidate() + assert c1.connection is None + eq_(c1_con.close.call_count, 1) + + c2 = p.connect() + assert c2.connection is not c1_con + c2_con = c2.connection + + c2.close() + eq_(c2_con.close.call_count, 0) def test_threadfairy(self): p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True) @@ -1141,8 +1155,13 @@ class SingletonThreadPoolTest(PoolTestBase): been called.""" dbapi = MockDBAPI() - p = pool.SingletonThreadPool(creator=dbapi.connect, - pool_size=3) + + lock = threading.Lock() + def creator(): + # the mock iterator isn't threadsafe... + with lock: + return dbapi.connect() + p = pool.SingletonThreadPool(creator=creator, pool_size=3) if strong_refs: sr = set() @@ -1172,7 +1191,7 @@ class SingletonThreadPoolTest(PoolTestBase): assert len(p._all_conns) == 3 if strong_refs: - still_opened = len([c for c in sr if not c.closed]) + still_opened = len([c for c in sr if not c.close.call_count]) eq_(still_opened, 3) class AssertionPoolTest(PoolTestBase): @@ -1198,17 +1217,19 @@ class NullPoolTest(PoolTestBase): dbapi = MockDBAPI() p = pool.NullPool(creator=lambda: dbapi.connect('foo.db')) c1 = p.connect() - c_id = c1.connection.id + c1.close() c1 = None c1 = p.connect() - dbapi.raise_error = True c1.invalidate() c1 = None c1 = p.connect() - assert c1.connection.id != c_id + dbapi.connect.assert_has_calls([ + call('foo.db'), + call('foo.db')], + any_order=True) class StaticPoolTest(PoolTestBase): |