diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/engine/reconnect.py | 215 | ||||
-rw-r--r-- | test/sql/testtypes.py | 2 | ||||
-rw-r--r-- | test/testlib/engines.py | 28 |
3 files changed, 227 insertions, 18 deletions
diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 7c213695f..f9d692b3d 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -13,37 +13,39 @@ class MockDBAPI(object): self.connections = weakref.WeakKeyDictionary() def connect(self, *args, **kwargs): return MockConnection(self) - + def shutdown(self): + for c in self.connections: + c.explode[0] = True + Error = MockDisconnect + class MockConnection(object): def __init__(self, dbapi): - self.explode = False dbapi.connections[self] = True + self.explode = [False] def rollback(self): pass def commit(self): pass def cursor(self): - return MockCursor(explode=self.explode) + return MockCursor(self) def close(self): pass class MockCursor(object): - def __init__(self, explode): - self.explode = explode + def __init__(self, parent): + self.explode = parent.explode self.description = None def execute(self, *args, **kwargs): - if self.explode: + if self.explode[0]: raise MockDisconnect("Lost the DB connection") else: return def close(self): pass -class ReconnectTest(PersistTest): - def test_reconnect(self): - """test that an 'is_disconnect' condition will invalidate the connection, and additionally - dispose the previous connection pool and recreate.""" - +class MockReconnectTest(PersistTest): + def setUp(self): + global db, dbapi dbapi = MockDBAPI() # create engine using our current dburi @@ -52,6 +54,11 @@ class ReconnectTest(PersistTest): # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) + def test_reconnect(self): + """test that an 'is_disconnect' condition will invalidate the connection, and additionally + dispose the previous connection pool and recreate.""" + + pid = id(db.pool) # make a connection @@ -68,17 +75,17 @@ class ReconnectTest(PersistTest): assert len(dbapi.connections) == 2 # set it to fail - conn.connection.connection.explode = True - + dbapi.shutdown() + try: - # execute should fail conn.execute("SELECT 1") assert False - except exceptions.SQLAlchemyError, e: + except exceptions.DBAPIError: pass # assert was invalidated - assert conn.connection.connection is None + assert not conn.closed + assert conn.invalidated # close shouldnt break conn.close() @@ -92,6 +99,182 @@ class ReconnectTest(PersistTest): conn.execute("SELECT 1") conn.close() assert len(dbapi.connections) == 1 + + def test_invalidate_trans(self): + conn = db.connect() + trans = conn.begin() + dbapi.shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + # assert was invalidated + assert len(dbapi.connections) == 0 + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + conn.execute("SELECT 1") + assert not conn.invalidated + + assert len(dbapi.connections) == 1 + + def test_conn_reusable(self): + conn = db.connect() + + conn.execute("SELECT 1") + + assert len(dbapi.connections) == 1 + + dbapi.shutdown() + + # raises error + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + assert not conn.closed + assert conn.invalidated + + # ensure all connections closed (pool was recycled) + assert len(dbapi.connections) == 0 + + # test reconnects + conn.execute("SELECT 1") + assert not conn.invalidated + assert len(dbapi.connections) == 1 + + +class RealReconnectTest(PersistTest): + def setUp(self): + global engine + engine = engines.reconnecting_engine() + + def tearDown(self): + engine.dispose() + + def test_reconnect(self): + conn = engine.connect() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + # one more time + engine.test_shutdown() + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + conn.close() + + def test_close(self): + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + conn.close() + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + + def test_with_transaction(self): + conn = engine.connect() + + trans = conn.begin() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + if __name__ == '__main__': testbase.main() diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 69696ec64..59acaff2f 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -342,7 +342,7 @@ class UnicodeTest(AssertMixin): assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'" finally: unicode_engine.dispose() - + @testing.unsupported('oracle') def testblanks(self): unicode_table.insert().execute(unicode_varchar=u'') diff --git a/test/testlib/engines.py b/test/testlib/engines.py index 56507618c..b576a1536 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -68,7 +68,31 @@ def close_open_connections(fn): decorated.__name__ = fn.__name__ return decorated - +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def shutdown(self): + for c in list(self.connections): + c.close() + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + engine = testing_engine(url, {'module':ReconnectFixture(dbapi)}) + engine.test_shutdown = engine.dialect.dbapi.shutdown + return engine + def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" @@ -109,3 +133,5 @@ def utf8_engine(url=None, options=None): url = str(url) return testing_engine(url, options) + + |