"""Tests for connection invalidation, reconnect and error recovery.

``MockReconnectTest`` and ``CursorErrTest`` drive the engine against a
``unittest.mock``-based DBAPI; ``RealReconnectTest``, ``RecycleTest`` and
``InvalidateDuringResultTest`` run against a real backend.
"""
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
import time
from sqlalchemy import (
    select, MetaData, Integer, String, create_engine, pool, exc, util)
from sqlalchemy.testing.schema import Table, Column
import sqlalchemy as tsa
from sqlalchemy import testing
from sqlalchemy.testing import mock
from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.mock import Mock, call, patch


class MockError(Exception):
    """Stand-in for the DBAPI ``Error`` class (installed by ``MockDBAPI``)."""
    pass


class MockDisconnect(MockError):
    """A ``MockError`` that ``MockReconnectTest`` treats as a disconnect."""
    pass


def mock_connection():
    """Return a mock DBAPI connection whose failures are scripted.

    Behaviour is selected by the connection's ``explode`` attribute, which
    ``MockDBAPI().shutdown()`` assigns:

    * ``'execute'`` -- ``cursor.execute()`` raises ``MockDisconnect``.
    * ``'execute_no_disconnect'`` -- ``cursor.execute()`` raises
      ``MockError``.
    * ``'rollback'`` / ``'rollback_no_disconnect'`` -- ``cursor.execute()``
      raises ``MockError`` and ``rollback()`` raises ``MockDisconnect`` /
      ``MockError`` respectively.

    While ``explode`` has never been assigned, reading it yields an
    auto-created child ``Mock``, which compares unequal to every string
    above, so the connection behaves normally.
    """
    def mock_cursor():
        def execute(*args, **kwargs):
            if conn.explode == 'execute':
                raise MockDisconnect("Lost the DB connection on execute")
            elif conn.explode in ('execute_no_disconnect', ):
                raise MockError(
                    "something broke on execute but we didn't lose the "
                    "connection")
            elif conn.explode in ('rollback', 'rollback_no_disconnect'):
                # fail execute() with a non-disconnect error so that the
                # rollback() issued while handling it is what explodes
                raise MockError(
                    "something broke on execute but we didn't lose the "
                    "connection")
            elif args and "SELECT" in args[0]:
                # a non-None description marks the statement as returning
                # rows (a single column named 'foo')
                cursor.description = [('foo', None, None, None, None, None)]
            else:
                return

        def close():
            # once closed, any fetch on this cursor fails
            cursor.fetchall = cursor.fetchone = \
                Mock(side_effect=MockError("cursor closed"))

        cursor = Mock(
            execute=Mock(side_effect=execute),
            close=Mock(side_effect=close))
        return cursor

    def cursor():
        # infinite generator: each conn.cursor() call gets a fresh cursor
        while True:
            yield mock_cursor()

    def rollback():
        if conn.explode == 'rollback':
            raise MockDisconnect("Lost the DB connection on rollback")
        if conn.explode == 'rollback_no_disconnect':
            raise MockError(
                "something broke on rollback but we didn't lose the "
                "connection")
        else:
            return

    conn = Mock(
        rollback=Mock(side_effect=rollback),
        cursor=Mock(side_effect=cursor()))
    return conn


def MockDBAPI():
    """Return a mock DBAPI module built on ``mock_connection()``.

    Every ``connect()`` call creates a new mock connection and records it in
    the ``connections`` list. ``shutdown(explode)`` sets ``explode`` only on
    the connections opened so far, so connections opened afterwards work
    normally; this is what lets the engine reconnect. ``dispose()`` resets
    ``explode`` and forgets all connections.
    """
    connections = []

    def connect():
        while True:
            conn = mock_connection()
            connections.append(conn)
            yield conn

    def shutdown(explode='execute'):
        for c in connections:
            c.explode = explode

    def dispose():
        for c in connections:
            c.explode = None
        connections[:] = []

    return Mock(
        connect=Mock(side_effect=connect()),
        shutdown=Mock(side_effect=shutdown),
        dispose=Mock(side_effect=dispose),
        paramstyle='named',
        connections=connections,
        Error=MockError)


class MockReconnectTest(fixtures.TestBase):
    """Invalidation and reconnect behaviour against ``MockDBAPI``."""

    def setup(self):
        self.dbapi = MockDBAPI()

        self.db = testing_engine(
            'postgresql://foo:bar@localhost/test',
            options=dict(module=self.dbapi, _initialize=False))

        # the call the dialect is expected to make to dbapi.connect()
        self.mock_connect = call(
            host='localhost', password='bar', user='foo', database='test')
        # monkeypatch disconnect checker: only MockDisconnect counts
        self.db.dialect.is_disconnect = \
            lambda e, conn, cursor: isinstance(e, MockDisconnect)

    def teardown(self):
        self.dbapi.dispose()

    def test_reconnect(self):
        """test that an 'is_disconnect' condition will invalidate the
        connection, and that connections pooled before the invalidation
        are closed and replaced when next checked out rather than reused."""

        # make a connection

        conn = self.db.connect()

        # connection works

        conn.execute(select([1]))

        # create a second connection within the pool, which we'll ensure
        # also goes away

        conn2 = self.db.connect()
        conn2.close()

        # two connections opened total now

        assert len(self.dbapi.connections) == 2

        # set it to fail

        self.dbapi.shutdown()
        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        # assert was invalidated

        assert not conn.closed
        assert conn.invalidated

        # close shouldn't break

        conn.close()

        # ensure one connection closed...
        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

        # ...and the stale pooled connection is closed on the next
        # checkout, with a brand new third connection opened instead
        conn = self.db.connect()

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], [call()], []]
        )

        conn.execute(select([1]))
        conn.close()

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], [call()], []]
        )

    def test_invalidate_trans(self):
        conn = self.db.connect()
        trans = conn.begin()
        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()]]
        )
        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active
        assert_raises_message(
            tsa.exc.StatementError,
            "Can't reconnect until invalid transaction is rolled back",
            conn.execute, select([1])
        )
        assert trans.is_active

        assert_raises_message(
            tsa.exc.InvalidRequestError,
            "Can't reconnect until invalid transaction is rolled back",
            trans.commit)

        assert trans.is_active
        trans.rollback()
        assert not trans.is_active
        conn.execute(select([1]))
        assert not conn.invalidated
        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

    def test_invalidate_dont_call_finalizer(self):
        conn = self.db.connect()
        finalizer = mock.Mock()
        conn.connection._connection_record.\
            finalize_callback.append(finalizer)
        conn.invalidate()
        assert conn.invalidated
        eq_(finalizer.call_count, 0)

    def test_conn_reusable(self):
        conn = self.db.connect()

        conn.execute(select([1]))

        eq_(
            self.dbapi.connect.mock_calls,
            [self.mock_connect]
        )

        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        assert not conn.closed
        assert conn.invalidated

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()]]
        )

        # test reconnects
        conn.execute(select([1]))
        assert not conn.invalidated

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

    def test_invalidated_close(self):
        conn = self.db.connect()

        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        conn.close()
        assert conn.closed
        assert conn.invalidated
        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_noreconnect_execute_plus_closewresult(self):
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("execute_no_disconnect")

        # raises error
        assert_raises_message(
            tsa.exc.DBAPIError,
            "something broke on execute but we didn't lose the connection",
            conn.execute, select([1])
        )

        assert conn.closed
        assert not conn.invalidated

    def test_noreconnect_rollback_plus_closewresult(self):
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("rollback_no_disconnect")

        # raises error
        assert_raises_message(
            tsa.exc.DBAPIError,
            "something broke on rollback but we didn't lose the connection",
            conn.execute, select([1])
        )

        assert conn.closed
        assert not conn.invalidated

        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_reconnect_on_reentrant(self):
        conn = self.db.connect()

        conn.execute(select([1]))

        assert len(self.dbapi.connections) == 1

        self.dbapi.shutdown("rollback")

        # raises error
        assert_raises_message(
            tsa.exc.DBAPIError,
            "Lost the DB connection on rollback",
            conn.execute, select([1])
        )

        assert not conn.closed
        assert conn.invalidated

    def test_reconnect_on_reentrant_plus_closewresult(self):
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("rollback")

        # raises error
        assert_raises_message(
            tsa.exc.DBAPIError,
            "Lost the DB connection on rollback",
            conn.execute, select([1])
        )

        assert conn.closed
        assert conn.invalidated

        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_check_disconnect_no_cursor(self):
        conn = self.db.connect()
        result = conn.execute(select([1]))
        result.cursor.close()
        conn.close()

        assert_raises_message(
            tsa.exc.DBAPIError,
            "cursor closed",
            list, result
        )

    def test_dialect_initialize_once(self):
        from sqlalchemy.engine.url import URL
        from sqlalchemy.engine.default import DefaultDialect
        dbapi = self.dbapi

        class MyURL(URL):
            def get_dialect(self):
                return Dialect

        class Dialect(DefaultDialect):
            initialize = Mock()

        engine = create_engine(MyURL("foo://"), module=dbapi)
        # keep references to both connections so neither is returned to
        # the pool while the test runs
        c1 = engine.connect()
        engine.dispose()
        c2 = engine.connect()
        eq_(Dialect.initialize.call_count, 1)


class CursorErrTest(fixtures.TestBase):
    # this isn't really a "reconnect" test, it's more of
    # a generic "recovery".   maybe this test suite should have been
    # named "test_error_recovery".

    def _fixture(self, explode_on_exec, initialize):
        """Build an engine on a minimal mock DBAPI with failing cursors.

        Every cursor's ``close()`` raises. When ``explode_on_exec`` is true,
        ``execute()`` raises the DBAPI's ``DBAPIError`` as well, and
        ``close()`` raises that same class instead of a plain ``Exception``.
        ``initialize`` is passed through as the engine's ``_initialize``
        option. The pool logger is replaced with a ``Mock`` so tests can
        inspect what was logged.
        """
        class DBAPIError(Exception):
            pass

        def MockDBAPI():
            def cursor():
                while True:
                    if explode_on_exec:
                        yield Mock(
                            description=[],
                            close=Mock(side_effect=DBAPIError("explode")),
                            execute=Mock(side_effect=DBAPIError("explode"))
                        )
                    else:
                        yield Mock(
                            description=[],
                            close=Mock(side_effect=Exception("explode")),
                        )

            def connect():
                while True:
                    yield Mock(
                        spec=['cursor', 'commit', 'rollback', 'close'],
                        cursor=Mock(side_effect=cursor()),)

            return Mock(
                Error=DBAPIError, paramstyle='qmark',
                connect=Mock(side_effect=connect()))
        dbapi = MockDBAPI()

        from sqlalchemy.engine import default
        url = Mock(
            get_dialect=lambda: default.DefaultDialect,
            translate_connect_args=lambda: {}, query={},)
        eng = testing_engine(
            url, options=dict(module=dbapi, _initialize=initialize))
        eng.pool.logger = Mock()
        return eng

    def test_cursor_explode(self):
        db = self._fixture(False, False)
        conn = db.connect()
        result = conn.execute("select foo")
        result.close()
        conn.close()
        eq_(
            db.pool.logger.error.mock_calls,
            [call('Error closing cursor', exc_info=True)]
        )

    def test_cursor_shutdown_in_initialize(self):
        db = self._fixture(True, True)
        assert_raises_message(
            exc.SAWarning,
            "Exception attempting to detect",
            db.connect
        )
        eq_(
            db.pool.logger.error.mock_calls,
            [call('Error closing cursor', exc_info=True)]
        )


def _assert_invalidated(fn, *args):
    """Call ``fn(*args)`` and require it to raise an invalidating error.

    Passes only if ``fn`` raises ``DBAPIError`` with
    ``connection_invalidated`` set. A ``DBAPIError`` without that flag is
    re-raised. If ``fn`` returns normally, ``AssertionError`` is raised.
    """
    try:
        fn(*args)
    except tsa.exc.DBAPIError as e:
        if not e.connection_invalidated:
            raise
    else:
        # an explicit raise (not ``assert False``) so the check still
        # fires when Python runs with -O and assert statements are stripped
        raise AssertionError(
            "%r did not raise an invalidating DBAPIError" % (fn, ))


class RealReconnectTest(fixtures.TestBase):
    """Invalidation/reconnect behaviour against a real database backend.

    ``self.engine.test_shutdown()`` comes from ``engines.reconnecting_engine``
    and is used here to simulate the database going away (presumably by
    closing the engine's underlying DBAPI connections -- see that helper).
    """
    __backend__ = True
    __requires__ = 'graceful_disconnects',

    def setup(self):
        self.engine = engines.reconnecting_engine()

    def teardown(self):
        self.engine.dispose()

    def test_reconnect(self):
        conn = self.engine.connect()

        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        self.engine.test_shutdown()

        _assert_invalidated(conn.execute, select([1]))

        assert not conn.closed
        assert conn.invalidated

        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        # one more time
        self.engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))

        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        conn.close()

    def test_multiple_invalidate(self):
        c1 = self.engine.connect()
        c2 = self.engine.connect()

        eq_(c1.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1.execute, select([1]))

        p2 = self.engine.pool

        _assert_invalidated(c2.execute, select([1]))

        # pool isn't replaced
        assert self.engine.pool is p2

    def test_branched_invalidate_branch_to_parent(self):
        c1 = self.engine.connect()

        with patch.object(self.engine.pool, "logger") as logger:
            c1_branch = c1.connect()
            eq_(c1_branch.execute(select([1])).scalar(), 1)

            self.engine.test_shutdown()

            _assert_invalidated(c1_branch.execute, select([1]))
            assert c1.invalidated
            assert c1_branch.invalidated

            c1_branch._revalidate_connection()
            assert not c1.invalidated
            assert not c1_branch.invalidated

        # first positional argument of the first logger call
        assert "Invalidate connection" in logger.mock_calls[0][1][0]

    def test_branched_invalidate_parent_to_branch(self):
        c1 = self.engine.connect()

        c1_branch = c1.connect()
        eq_(c1_branch.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1.execute, select([1]))
        assert c1.invalidated
        assert c1_branch.invalidated

        c1._revalidate_connection()
        assert not c1.invalidated
        assert not c1_branch.invalidated

    def test_branch_invalidate_state(self):
        c1 = self.engine.connect()

        c1_branch = c1.connect()

        eq_(c1_branch.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1_branch.execute, select([1]))
        assert not c1_branch.closed
        assert not c1_branch._connection_is_valid

    def test_ensure_is_disconnect_gets_connection(self):
        def is_disconnect(e, conn, cursor):
            # connection is still present
            assert conn.connection is not None
            # the error usually occurs on connection.cursor(),
            # though MySQLdb we get a non-working cursor.
            # assert cursor is None

        self.engine.dialect.is_disconnect = is_disconnect
        conn = self.engine.connect()
        self.engine.test_shutdown()
        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

    def test_rollback_on_invalid_plain(self):
        conn = self.engine.connect()
        trans = conn.begin()
        conn.invalidate()
        trans.rollback()

    @testing.requires.two_phase_transactions
    def test_rollback_on_invalid_twophase(self):
        conn = self.engine.connect()
        trans = conn.begin_twophase()
        conn.invalidate()
        trans.rollback()

    @testing.requires.savepoints
    def test_rollback_on_invalid_savepoint(self):
        conn = self.engine.connect()
        trans = conn.begin()
        trans2 = conn.begin_nested()
        conn.invalidate()
        trans2.rollback()

    def test_invalidate_twice(self):
        conn = self.engine.connect()
        conn.invalidate()
        conn.invalidate()

    @testing.skip_if(
        [lambda: util.py3k, "oracle+cx_oracle"],
        "Crashes on py3k+cx_oracle")
    def test_explode_in_initializer(self):
        engine = engines.testing_engine()

        def broken_initialize(connection):
            connection.execute("select fake_stuff from _fake_table")

        engine.dialect.initialize = broken_initialize

        # raises a DBAPIError, not an AttributeError
        assert_raises(exc.DBAPIError, engine.connect)

    @testing.skip_if(
        [lambda: util.py3k, "oracle+cx_oracle"],
        "Crashes on py3k+cx_oracle")
    def test_explode_in_initializer_disconnect(self):
        engine = engines.testing_engine()

        def broken_initialize(connection):
            connection.execute("select fake_stuff from _fake_table")

        engine.dialect.initialize = broken_initialize

        def is_disconnect(e, conn, cursor):
            return True

        engine.dialect.is_disconnect = is_disconnect

        # invalidate() also doesn't screw up
        assert_raises(exc.DBAPIError, engine.connect)

    def test_null_pool(self):
        engine = engines.reconnecting_engine(
            options=dict(poolclass=pool.NullPool))
        conn = engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed
        engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))
        assert not conn.closed
        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

    def test_close(self):
        conn = self.engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        self.engine.test_shutdown()

        _assert_invalidated(conn.execute, select([1]))

        conn.close()
        conn = self.engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)

    def test_with_transaction(self):
        conn = self.engine.connect()
        trans = conn.begin()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed
        self.engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))
        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active
        assert_raises_message(
            tsa.exc.StatementError,
            "Can't reconnect until invalid transaction is rolled back",
            conn.execute, select([1]))
        assert trans.is_active
        assert_raises_message(
            tsa.exc.InvalidRequestError,
            "Can't reconnect until invalid transaction is rolled back",
            trans.commit
        )
        assert trans.is_active
        trans.rollback()
        assert not trans.is_active
        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated


class RecycleTest(fixtures.TestBase):
    """Pool recycling after the server connection has been killed."""
    __backend__ = True

    def test_basic(self):
        # exercise both the plain and the thread-local pool configuration
        for threadlocal in False, True:
            engine = engines.reconnecting_engine(
                options={'pool_threadlocal': threadlocal})

            conn = engine.contextual_connect()
            eq_(conn.execute(select([1])).scalar(), 1)
            conn.close()

            # set the pool recycle down to 1.
            # we aren't doing this inline with the
            # engine create since cx_oracle takes way
            # too long to create the 1st connection and don't
            # want to build a huge delay into this test.

            engine.pool._recycle = 1

            # kill the DB connection
            engine.test_shutdown()

            # wait until past the recycle period
            time.sleep(2)

            # can connect, no exception
            conn = engine.contextual_connect()
            eq_(conn.execute(select([1])).scalar(), 1)
            conn.close()


class InvalidateDuringResultTest(fixtures.TestBase):
    """Invalidation detected while fetching from a partially-read result."""
    __backend__ = True

    def setup(self):
        self.engine = engines.reconnecting_engine()
        self.meta = MetaData(self.engine)
        table = Table(
            'sometable', self.meta,
            Column('id', Integer, primary_key=True),
            Column('name', String(50)))
        self.meta.create_all()
        # 99 rows: ids 1..99
        table.insert().execute(
            [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)]
        )

    def teardown(self):
        self.meta.drop_all()
        self.engine.dispose()

    @testing.fails_if([
        '+mysqlconnector', '+mysqldb', '+cymysql', '+pymysql', '+pg8000'],
        "Buffers the result set and doesn't check for connection close")
    def test_invalidate_on_results(self):
        conn = self.engine.connect()
        result = conn.execute('select * from sometable')
        # consume part of the result set before the connection goes away
        for _ in range(20):
            result.fetchone()
        self.engine.test_shutdown()
        _assert_invalidated(result.fetchone)
        assert conn.invalidated