summaryrefslogtreecommitdiff
path: root/test/engine/reconnect.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/reconnect.py')
-rw-r--r--test/engine/reconnect.py215
1 files changed, 199 insertions, 16 deletions
diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py
index 7c213695f..f9d692b3d 100644
--- a/test/engine/reconnect.py
+++ b/test/engine/reconnect.py
@@ -13,37 +13,39 @@ class MockDBAPI(object):
self.connections = weakref.WeakKeyDictionary()
def connect(self, *args, **kwargs):
return MockConnection(self)
-
+ def shutdown(self):
+ for c in self.connections:
+ c.explode[0] = True
+ Error = MockDisconnect
+
class MockConnection(object):
def __init__(self, dbapi):
- self.explode = False
dbapi.connections[self] = True
+ self.explode = [False]
def rollback(self):
pass
def commit(self):
pass
def cursor(self):
- return MockCursor(explode=self.explode)
+ return MockCursor(self)
def close(self):
pass
class MockCursor(object):
- def __init__(self, explode):
- self.explode = explode
+ def __init__(self, parent):
+ self.explode = parent.explode
self.description = None
def execute(self, *args, **kwargs):
- if self.explode:
+ if self.explode[0]:
raise MockDisconnect("Lost the DB connection")
else:
return
def close(self):
pass
-class ReconnectTest(PersistTest):
- def test_reconnect(self):
- """test that an 'is_disconnect' condition will invalidate the connection, and additionally
- dispose the previous connection pool and recreate."""
-
+class MockReconnectTest(PersistTest):
+ def setUp(self):
+ global db, dbapi
dbapi = MockDBAPI()
# create engine using our current dburi
@@ -52,6 +54,11 @@ class ReconnectTest(PersistTest):
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
+ def test_reconnect(self):
+ """test that an 'is_disconnect' condition will invalidate the connection, and additionally
+ dispose the previous connection pool and recreate."""
+
+
pid = id(db.pool)
# make a connection
@@ -68,17 +75,17 @@ class ReconnectTest(PersistTest):
assert len(dbapi.connections) == 2
# set it to fail
- conn.connection.connection.explode = True
-
+ dbapi.shutdown()
+
try:
- # execute should fail
conn.execute("SELECT 1")
assert False
- except exceptions.SQLAlchemyError, e:
+ except exceptions.DBAPIError:
pass
# assert was invalidated
- assert conn.connection.connection is None
+ assert not conn.closed
+ assert conn.invalidated
# close shouldnt break
conn.close()
@@ -92,6 +99,182 @@ class ReconnectTest(PersistTest):
conn.execute("SELECT 1")
conn.close()
assert len(dbapi.connections) == 1
+
+ def test_invalidate_trans(self):
+ conn = db.connect()
+ trans = conn.begin()
+ dbapi.shutdown()
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError:
+ pass
+
+ # assert was invalidated
+ assert len(dbapi.connections) == 0
+ assert not conn.closed
+ assert conn.invalidated
+ assert trans.is_active
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+ assert trans.is_active
+
+ try:
+ trans.commit()
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+ assert trans.is_active
+
+ trans.rollback()
+ assert not trans.is_active
+
+ conn.execute("SELECT 1")
+ assert not conn.invalidated
+
+ assert len(dbapi.connections) == 1
+
+ def test_conn_reusable(self):
+ conn = db.connect()
+
+ conn.execute("SELECT 1")
+
+ assert len(dbapi.connections) == 1
+
+ dbapi.shutdown()
+
+ # raises error
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError:
+ pass
+
+ assert not conn.closed
+ assert conn.invalidated
+
+ # ensure all connections closed (pool was recycled)
+ assert len(dbapi.connections) == 0
+
+ # test reconnects
+ conn.execute("SELECT 1")
+ assert not conn.invalidated
+ assert len(dbapi.connections) == 1
+
+
+class RealReconnectTest(PersistTest):
+ def setUp(self):
+ global engine
+ engine = engines.reconnecting_engine()
+
+ def tearDown(self):
+ engine.dispose()
+
+ def test_reconnect(self):
+ conn = engine.connect()
+
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.closed
+
+ engine.test_shutdown()
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError, e:
+ if not e.connection_invalidated:
+ raise
+
+ assert not conn.closed
+ assert conn.invalidated
+
+ assert conn.invalidated
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.invalidated
+
+ # one more time
+ engine.test_shutdown()
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError, e:
+ if not e.connection_invalidated:
+ raise
+ assert conn.invalidated
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.invalidated
+
+ conn.close()
+
+ def test_close(self):
+ conn = engine.connect()
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.closed
+
+ engine.test_shutdown()
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError, e:
+ if not e.connection_invalidated:
+ raise
+
+ conn.close()
+ conn = engine.connect()
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+
+ def test_with_transaction(self):
+ conn = engine.connect()
+
+ trans = conn.begin()
+
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.closed
+
+ engine.test_shutdown()
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.DBAPIError, e:
+ if not e.connection_invalidated:
+ raise
+
+ assert not conn.closed
+ assert conn.invalidated
+ assert trans.is_active
+
+ try:
+ conn.execute("SELECT 1")
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+ assert trans.is_active
+
+ try:
+ trans.commit()
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Can't reconnect until invalid transaction is rolled back"
+
+ assert trans.is_active
+
+ trans.rollback()
+ assert not trans.is_active
+
+ assert conn.invalidated
+ self.assertEquals(conn.execute("SELECT 1").scalar(), 1)
+ assert not conn.invalidated
+
if __name__ == '__main__':
testbase.main()