diff options
Diffstat (limited to 'test/engine/test_reconnect.py')
-rw-r--r-- | test/engine/test_reconnect.py | 443 |
1 files changed, 218 insertions, 225 deletions
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index ee3ff1459..86003bec6 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,7 +1,6 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message import time -import weakref -from sqlalchemy import select, MetaData, Integer, String, pool, create_engine +from sqlalchemy import select, MetaData, Integer, String, create_engine, pool from sqlalchemy.testing.schema import Table, Column import sqlalchemy as tsa from sqlalchemy import testing @@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy import exc, util from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine +from sqlalchemy.testing import is_not_ +from sqlalchemy.testing.mock import Mock, call class MockError(Exception): pass @@ -17,93 +18,103 @@ class MockError(Exception): class MockDisconnect(MockError): pass -class MockDBAPI(object): - def __init__(self): - self.paramstyle = 'named' - self.connections = weakref.WeakKeyDictionary() - def connect(self, *args, **kwargs): - return MockConnection(self) - def shutdown(self, explode='execute'): - for c in self.connections: - c.explode = explode - Error = MockError - -class MockConnection(object): - def __init__(self, dbapi): - dbapi.connections[self] = True - self.explode = "" - def rollback(self): - if self.explode == 'rollback': +def mock_connection(): + def mock_cursor(): + def execute(*args, **kwargs): + if conn.explode == 'execute': + raise MockDisconnect("Lost the DB connection on execute") + elif conn.explode in ('execute_no_disconnect', ): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif conn.explode in ('rollback', 'rollback_no_disconnect'): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif args and "SELECT" in args[0]: + cursor.description = [('foo', None, None, None, None, None)] + else: + return + + def close(): + cursor.fetchall = cursor.fetchone = \ + Mock(side_effect=MockError("cursor closed")) + cursor = Mock( + execute=Mock(side_effect=execute), + close=Mock(side_effect=close) + ) + return cursor + + def cursor(): + while True: + yield mock_cursor() + + def rollback(): + if conn.explode == 'rollback': raise MockDisconnect("Lost the DB connection on rollback") - if self.explode == 'rollback_no_disconnect': + if conn.explode == 'rollback_no_disconnect': raise MockError( "something broke on rollback but we didn't lose the connection") else: return - def commit(self): - pass - def cursor(self): - return MockCursor(self) - def close(self): - pass - -class MockCursor(object): - def __init__(self, parent): - self.explode = parent.explode - self.description = () - self.closed = False - def execute(self, *args, **kwargs): - if self.explode == 'execute': - raise MockDisconnect("Lost the DB connection on execute") - elif self.explode in ('execute_no_disconnect', ): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif self.explode in ('rollback', 'rollback_no_disconnect'): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif args and "select" in args[0]: - self.description = [('foo', None, None, None, None, None)] - else: - return - def fetchall(self): - if self.closed: - raise MockError("cursor closed") - return [] - def fetchone(self): - if self.closed: - raise MockError("cursor closed") - return None - def close(self): - self.closed = True - -db, dbapi = None, None + + conn = Mock( + rollback=Mock(side_effect=rollback), + cursor=Mock(side_effect=cursor()) + ) + return conn + +def MockDBAPI(): + connections = [] + def connect(): + while True: + conn = mock_connection() + connections.append(conn) + yield conn + + def shutdown(explode='execute'): + for c in connections: + c.explode = explode + + def dispose(): + for c in connections: + c.explode = None + connections[:] = [] + + return Mock( + connect=Mock(side_effect=connect()), + shutdown=Mock(side_effect=shutdown), + dispose=Mock(side_effect=dispose), + paramstyle='named', + connections=connections, + Error=MockError + ) + + class MockReconnectTest(fixtures.TestBase): def setup(self): - global db, dbapi - dbapi = MockDBAPI() + self.dbapi = MockDBAPI() - # note - using straight create_engine here - # since we are testing gc - db = create_engine( + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', - module=dbapi, _initialize=False) + options=dict(module=self.dbapi, _initialize=False)) + self.mock_connect = call(host='localhost', password='bar', + user='foo', database='test') # monkeypatch disconnect checker - db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) + self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) def teardown(self): - db.dispose() + self.dbapi.dispose() def test_reconnect(self): """test that an 'is_disconnect' condition will invalidate the connection, and additionally dispose the previous connection pool and recreate.""" - pid = id(db.pool) + db_pool = self.db.pool # make a connection - conn = db.connect() + conn = self.db.connect() # connection works @@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase): # create a second connection within the pool, which we'll ensure # also goes away - conn2 = db.connect() + conn2 = self.db.connect() conn2.close() # two connections opened total now - assert len(dbapi.connections) == 2 + assert len(self.dbapi.connections) == 2 # set it to fail - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) # assert was invalidated @@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase): # close shouldnt break conn.close() - assert id(db.pool) != pid + is_not_(self.db.pool, db_pool) # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 - conn = db.connect() + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()]] + ) + + conn = self.db.connect() conn.execute(select([1])) conn.close() - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()], []] + ) def test_invalidate_trans(self): - conn = db.connect() + conn = self.db.connect() trans = conn.begin() - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() - # assert was invalidated + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is " + "rolled back", + trans.commit + ) + assert trans.is_active trans.rollback() assert not trans.is_active conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_conn_reusable(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + eq_( + self.dbapi.connect.mock_calls, + [self.mock_connect] + ) - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.closed assert conn.invalidated - # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) # test reconnects conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_invalidated_close(self): - conn = db.connect() + conn = self.db.connect() - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase): ) def test_noreconnect_execute_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("execute_no_disconnect") + self.dbapi.shutdown("execute_no_disconnect") # raises error assert_raises_message( @@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated def test_noreconnect_rollback_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback_no_disconnect") + self.dbapi.shutdown("rollback_no_disconnect") # raises error assert_raises_message( @@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase): ) def test_reconnect_on_reentrant(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + assert len(self.dbapi.connections) == 1 - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase): assert conn.invalidated def test_reconnect_on_reentrant_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase): ) def test_check_disconnect_no_cursor(self): - conn = db.connect() - result = conn.execute("select 1") + conn = self.db.connect() + result = conn.execute(select([1])) result.cursor.close() conn.close() + assert_raises_message( tsa.exc.DBAPIError, "cursor closed", @@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase): class CursorErrTest(fixtures.TestBase): def setup(self): - global db, dbapi - - class MDBAPI(MockDBAPI): - def connect(self, *args, **kwargs): - return MConn(self) - - class MConn(MockConnection): - def cursor(self): - return MCursor(self) + def MockDBAPI(): + def cursor(): + while True: + yield Mock( + description=[], + close=Mock(side_effect=Exception("explode"))) + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) - class MCursor(MockCursor): - def close(self): - raise Exception("explode") - - dbapi = MDBAPI() - - db = testing_engine( + dbapi = MockDBAPI() + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', options=dict(module=dbapi, _initialize=False)) def test_cursor_explode(self): - conn = db.connect() + conn = self.db.connect() result = conn.execute("select foo") result.close() conn.close() def teardown(self): - db.dispose() + self.db.dispose() + + +def _assert_invalidated(fn, *args): + try: + fn(*args) + assert False + except tsa.exc.DBAPIError as e: + if not e.connection_invalidated: + raise -engine = None class RealReconnectTest(fixtures.TestBase): def setup(self): - global engine - engine = engines.reconnecting_engine() + self.engine = engines.reconnecting_engine() def teardown(self): - engine.dispose() + self.engine.dispose() @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_reconnect(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated @@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase): assert not conn.invalidated # one more time - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) + assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated @@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase): conn.close() def test_multiple_invalidate(self): - c1 = engine.connect() - c2 = engine.connect() + c1 = self.engine.connect() + c2 = self.engine.connect() eq_(c1.execute(select([1])).scalar(), 1) - p1 = engine.pool - engine.test_shutdown() + p1 = self.engine.pool + self.engine.test_shutdown() - try: - c1.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c1.execute, select([1])) - p2 = engine.pool + p2 = self.engine.pool - try: - c2.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c2.execute, select([1])) # pool isn't replaced - assert engine.pool is p2 + assert self.engine.pool is p2 def test_ensure_is_disconnect_gets_connection(self): @@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase): # though MySQLdb we get a non-working cursor. # assert cursor is None - engine.dialect.is_disconnect = is_disconnect - conn = engine.connect() - engine.test_shutdown() + self.engine.dialect.is_disconnect = is_disconnect + conn = self.engine.connect() + self.engine.test_shutdown() assert_raises( tsa.exc.DBAPIError, conn.execute, select([1]) ) def test_rollback_on_invalid_plain(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() conn.invalidate() trans.rollback() @testing.requires.two_phase_transactions def test_rollback_on_invalid_twophase(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin_twophase() conn.invalidate() trans.rollback() @testing.requires.savepoints def test_rollback_on_invalid_savepoint(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() trans2 = conn.begin_nested() conn.invalidate() trans2.rollback() def test_invalidate_twice(self): - conn = engine.connect() + conn = self.engine.connect() conn.invalidate() conn.invalidate() @@ -503,12 +520,7 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) @@ -517,37 +529,27 @@ class RealReconnectTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_close(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) conn.close() - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_with_transaction(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -558,13 +560,11 @@ class RealReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit + ) assert trans.is_active trans.rollback() assert not trans.is_active @@ -602,23 +602,21 @@ class RecycleTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) conn.close() -meta, table, engine = None, None, None class InvalidateDuringResultTest(fixtures.TestBase): def setup(self): - global meta, table, engine - engine = engines.reconnecting_engine() - meta = MetaData(engine) - table = Table('sometable', meta, + self.engine = engines.reconnecting_engine() + self.meta = MetaData(self.engine) + table = Table('sometable', self.meta, Column('id', Integer, primary_key=True), Column('name', String(50))) - meta.create_all() + self.meta.create_all() table.insert().execute( - [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)] ) def teardown(self): - meta.drop_all() - engine.dispose() + self.meta.drop_all() + self.engine.dispose() @testing.fails_if([ '+mysqlconnector', '+mysqldb', @@ -628,16 +626,11 @@ class InvalidateDuringResultTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_invalidate_on_results(self): - conn = engine.connect() + conn = self.engine.connect() result = conn.execute('select * from sometable') for x in range(20): result.fetchone() - engine.test_shutdown() - try: - print('ghost result: %r' % result.fetchone()) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(result.fetchone) assert conn.invalidated |