diff options
Diffstat (limited to 'test/engine/reconnect.py')
-rw-r--r-- | test/engine/reconnect.py | 215 |
1 files changed, 199 insertions, 16 deletions
diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 7c213695f..f9d692b3d 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -13,37 +13,39 @@ class MockDBAPI(object): self.connections = weakref.WeakKeyDictionary() def connect(self, *args, **kwargs): return MockConnection(self) - + def shutdown(self): + for c in self.connections: + c.explode[0] = True + Error = MockDisconnect + class MockConnection(object): def __init__(self, dbapi): - self.explode = False dbapi.connections[self] = True + self.explode = [False] def rollback(self): pass def commit(self): pass def cursor(self): - return MockCursor(explode=self.explode) + return MockCursor(self) def close(self): pass class MockCursor(object): - def __init__(self, explode): - self.explode = explode + def __init__(self, parent): + self.explode = parent.explode self.description = None def execute(self, *args, **kwargs): - if self.explode: + if self.explode[0]: raise MockDisconnect("Lost the DB connection") else: return def close(self): pass -class ReconnectTest(PersistTest): - def test_reconnect(self): - """test that an 'is_disconnect' condition will invalidate the connection, and additionally - dispose the previous connection pool and recreate.""" - +class MockReconnectTest(PersistTest): + def setUp(self): + global db, dbapi dbapi = MockDBAPI() # create engine using our current dburi @@ -52,6 +54,11 @@ class ReconnectTest(PersistTest): # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) + def test_reconnect(self): + """test that an 'is_disconnect' condition will invalidate the connection, and additionally + dispose the previous connection pool and recreate.""" + + pid = id(db.pool) # make a connection @@ -68,17 +75,17 @@ class ReconnectTest(PersistTest): assert len(dbapi.connections) == 2 # set it to fail - conn.connection.connection.explode = True - + dbapi.shutdown() + try: - # execute should fail conn.execute("SELECT 1") assert False - except exceptions.SQLAlchemyError, e: + except exceptions.DBAPIError: pass # assert was invalidated - assert conn.connection.connection is None + assert not conn.closed + assert conn.invalidated # close shouldnt break conn.close() @@ -92,6 +99,182 @@ class ReconnectTest(PersistTest): conn.execute("SELECT 1") conn.close() assert len(dbapi.connections) == 1 + + def test_invalidate_trans(self): + conn = db.connect() + trans = conn.begin() + dbapi.shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + # assert was invalidated + assert len(dbapi.connections) == 0 + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + conn.execute("SELECT 1") + assert not conn.invalidated + + assert len(dbapi.connections) == 1 + + def test_conn_reusable(self): + conn = db.connect() + + conn.execute("SELECT 1") + + assert len(dbapi.connections) == 1 + + dbapi.shutdown() + + # raises error + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + assert not conn.closed + assert conn.invalidated + + # ensure all connections closed (pool was recycled) + assert len(dbapi.connections) == 0 + + # test reconnects + conn.execute("SELECT 1") + assert not conn.invalidated + assert len(dbapi.connections) == 1 + + +class RealReconnectTest(PersistTest): + def setUp(self): + global engine + engine = engines.reconnecting_engine() + + def tearDown(self): + engine.dispose() + + def test_reconnect(self): + conn = engine.connect() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + # one more time + engine.test_shutdown() + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + conn.close() + + def test_close(self): + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + conn.close() + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + + def test_with_transaction(self): + conn = engine.connect() + + trans = conn.begin() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + if __name__ == '__main__': testbase.main() |