diff options
-rw-r--r-- | CHANGES | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 105 | ||||
-rw-r--r-- | lib/sqlalchemy/exceptions.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/pool.py | 7 | ||||
-rw-r--r-- | test/engine/reconnect.py | 215 | ||||
-rw-r--r-- | test/sql/testtypes.py | 2 | ||||
-rw-r--r-- | test/testlib/engines.py | 28 |
9 files changed, 331 insertions, 59 deletions
@@ -13,6 +13,15 @@ CHANGES knows that its return type is a Date. We only have a few functions represented so far but will continue to add to the system [ticket:615] + - auto-reconnect support improved; a Connection can now automatically + reconnect after its underlying connection is invalidated, without + needing to connect() again from the engine. This allows an ORM session + bound to a single Connection to not need a reconnect. + Open transactions on the Connection must be rolled back after an invalidation + of the underlying connection else an error is raised. Also fixed + bug where disconnect detect was not being called for cursor(), rollback(), + or commit(). + - added new flag to String and create_engine(), assert_unicode=(True|False|'warn'|None). Defaults to `False` or `None` on create_engine() and String, `'warn'` on the Unicode type. When `True`, diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 122c24bff..a738887f4 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1527,8 +1527,12 @@ class MySQLDialect(default.DefaultDialect): connection.ping() def is_disconnect(self, e): - return isinstance(e, self.dbapi.OperationalError) and \ - e.args[0] in (2006, 2013, 2014, 2045, 2055) + if isinstance(e, self.dbapi.OperationalError): + return e.args[0] in (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get + return "(0, '')" in str(e) + else: + return False def get_default_schema_name(self, connection): try: diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 16dd9427c..e028b1c53 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -237,6 +237,9 @@ class SQLiteDialect(default.DefaultDialect): def oid_column_name(self, column): return "oid" + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + def table_names(self, connection, schema): s = "SELECT name FROM sqlite_master WHERE type='table'" return [row[0] for row in connection.execute(s)] diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ff2245f39..57c76e7ed 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -542,6 +542,7 @@ class Connection(Connectable): self.__close_with_result = close_with_result self.__savepoint_seq = 0 self.__branch = _branch + self.__invalid = False def _branch(self): """Return a new Connection which references this Connection's @@ -559,12 +560,30 @@ class Connection(Connectable): return self.engine.dialect dialect = property(dialect) + def closed(self): + """return True if this connection is closed.""" + + return not self.__invalid and '_Connection__connection' not in self.__dict__ + closed = property(closed) + + def invalidated(self): + """return True if this connection was invalidated.""" + + return self.__invalid + invalidated = property(invalidated) + def connection(self): "The underlying DB-API connection managed by this Connection." try: return self.__connection except AttributeError: + if self.__invalid: + if self.__transaction is not None: + raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") + self.__connection = self.engine.raw_connection() + self.__invalid = False + return self.__connection raise exceptions.InvalidRequestError("This Connection is closed") connection = property(connection) @@ -603,16 +622,28 @@ class Connection(Connectable): return self - def invalidate(self): - """Invalidate and close the Connection. + def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with this Connection. The underlying DB-API connection is literally closed (if possible), and is discarded. Its source connection pool will typically lazilly create a new connection to replace it. + + Upon the next usage, this Connection will attempt to reconnect + to the pool with a new connection. + + Transactions in progress remain in an "opened" state (even though + the actual transaction is gone); these must be explicitly + rolled back before a reconnect on this Connection can proceed. This + is to prevent applications from accidentally continuing their transactional + operations in a non-transactional state. + """ - self.__connection.invalidate() - self.__connection = None + if self.__connection.is_valid: + self.__connection.invalidate(exception) + del self.__connection + self.__invalid = True def detach(self): """Detach the underlying DB-API connection from its connection pool. @@ -699,29 +730,31 @@ class Connection(Connectable): if self.engine._should_log_info: self.engine.logger.info("BEGIN") try: - self.engine.dialect.do_begin(self.__connection) + self.engine.dialect.do_begin(self.connection) except Exception, e: - raise exceptions.DBAPIError.instance(None, None, e) + raise self.__handle_dbapi_exception(e, None, None, None) def _rollback_impl(self): - if self.__connection.is_valid: + if not self.closed and not self.invalidated and self.__connection.is_valid: if self.engine._should_log_info: self.engine.logger.info("ROLLBACK") try: - self.engine.dialect.do_rollback(self.__connection) + self.engine.dialect.do_rollback(self.connection) + self.__transaction = None except Exception, e: - raise exceptions.DBAPIError.instance(None, None, e) - self.__transaction = None + raise self.__handle_dbapi_exception(e, None, None, None) + else: + self.__transaction = None def _commit_impl(self): if self.engine._should_log_info: self.engine.logger.info("COMMIT") try: - self.engine.dialect.do_commit(self.__connection) + self.engine.dialect.do_commit(self.connection) + self.__transaction = None except Exception, e: - raise exceptions.DBAPIError.instance(None, None, e) - self.__transaction = None - + raise self.__handle_dbapi_exception(e, None, None, None) + def _savepoint_impl(self, name=None): if name is None: self.__savepoint_seq += 1 @@ -789,6 +822,7 @@ class Connection(Connectable): if not self.__branch: self.__connection.close() self.__connection = None + self.__invalid = False del self.__connection def scalar(self, object, *multiparams, **params): @@ -872,15 +906,32 @@ class Connection(Connectable): self._autocommit(context) return context.result() - def __create_execution_context(self, **kwargs): - return self.engine.dialect.create_execution_context(connection=self, **kwargs) - def __execute_raw(self, context): if context.executemany: self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) else: self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) + + def __handle_dbapi_exception(self, e, statement, parameters, cursor): + if not isinstance(e, self.dialect.dbapi.Error): + return e + is_disconnect = self.dialect.is_disconnect(e) + if is_disconnect: + self.invalidate(e) + self.engine.dispose() + if cursor: + cursor.close() + self._autorollback() + if self.__close_with_result: + self.close() + return exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + def __create_execution_context(self, **kwargs): + try: + return self.engine.dialect.create_execution_context(connection=self, **kwargs) + except Exception, e: + raise self.__handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None) + def _cursor_execute(self, cursor, statement, parameters, context=None): if self.engine._should_log_info: self.engine.logger.info(statement) @@ -888,14 +939,7 @@ class Connection(Connectable): try: self.dialect.do_execute(cursor, statement, parameters, context=context) except Exception, e: - if self.dialect.is_disconnect(e): - self.__connection.invalidate(e=e) - self.engine.dispose() - cursor.close() - self._autorollback() - if self.__close_with_result: - self.close() - raise exceptions.DBAPIError.instance(statement, parameters, e) + raise self.__handle_dbapi_exception(e, statement, parameters, cursor) def _cursor_executemany(self, cursor, statement, parameters, context=None): if self.engine._should_log_info: @@ -904,14 +948,7 @@ class Connection(Connectable): try: self.dialect.do_executemany(cursor, statement, parameters, context=context) except Exception, e: - if self.dialect.is_disconnect(e): - self.__connection.invalidate(e=e) - self.engine.dispose() - cursor.close() - self._autorollback() - if self.__close_with_result: - self.close() - raise exceptions.DBAPIError.instance(statement, parameters, e) + raise self.__handle_dbapi_exception(e, statement, parameters, cursor) # poor man's multimethod/generic function thingy executors = { @@ -990,8 +1027,8 @@ class Transaction(object): def commit(self): if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") - self._is_active = False self._do_commit() + self._is_active = False def _do_commit(self): pass diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index 8338bc554..530ce3e3a 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -62,7 +62,11 @@ class NoSuchColumnError(KeyError, SQLAlchemyError): class DisconnectionError(SQLAlchemyError): - """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection.""" + """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection. + + This error is consumed internally by a connection pool. It can be raised by + a ``PoolListener`` so that the host pool forces a disconnect. + """ class DBAPIError(SQLAlchemyError): @@ -84,7 +88,7 @@ class DBAPIError(SQLAlchemyError): Its type and properties are DB-API implementation specific. """ - def instance(cls, statement, params, orig): + def instance(cls, statement, params, orig, connection_invalidated=False): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. if isinstance(orig, (KeyboardInterrupt, SystemExit)): @@ -95,10 +99,10 @@ class DBAPIError(SQLAlchemyError): if name in glob and issubclass(glob[name], DBAPIError): cls = glob[name] - return cls(statement, params, orig) + return cls(statement, params, orig, connection_invalidated) instance = classmethod(instance) - def __init__(self, statement, params, orig): + def __init__(self, statement, params, orig, connection_invalidated=False): try: text = str(orig) except (KeyboardInterrupt, SystemExit): @@ -110,6 +114,7 @@ class DBAPIError(SQLAlchemyError): self.statement = statement self.params = params self.orig = orig + self.connection_invalidated = connection_invalidated def __str__(self): return ' '.join([SQLAlchemyError.__str__(self), diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index ff38f21b8..7a5c2ef0e 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -208,7 +208,12 @@ class _ConnectionRecord(object): if self.connection is not None: if self.__pool._should_log_info: self.__pool.log("Closing connection %s" % repr(self.connection)) - self.connection.close() + try: + self.connection.close() + except: + if self.__pool._should_log_info: + self.__pool.log("Exception closing connection %s" % repr(self.connection)) + def invalidate(self, e=None): if self.__pool._should_log_info: diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 7c213695f..f9d692b3d 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -13,37 +13,39 @@ class MockDBAPI(object): self.connections = weakref.WeakKeyDictionary() def connect(self, *args, **kwargs): return MockConnection(self) - + def shutdown(self): + for c in self.connections: + c.explode[0] = True + Error = MockDisconnect + class MockConnection(object): def __init__(self, dbapi): - self.explode = False dbapi.connections[self] = True + self.explode = [False] def rollback(self): pass def commit(self): pass def cursor(self): - return MockCursor(explode=self.explode) + return MockCursor(self) def close(self): pass class MockCursor(object): - def __init__(self, explode): - self.explode = explode + def __init__(self, parent): + self.explode = parent.explode self.description = None def execute(self, *args, **kwargs): - if self.explode: + if self.explode[0]: raise MockDisconnect("Lost the DB connection") else: return def close(self): pass -class ReconnectTest(PersistTest): - def test_reconnect(self): - """test that an 'is_disconnect' condition will invalidate the connection, and additionally - dispose the previous connection pool and recreate.""" - +class MockReconnectTest(PersistTest): + def setUp(self): + global db, dbapi dbapi = MockDBAPI() # create engine using our current dburi @@ -52,6 +54,11 @@ class ReconnectTest(PersistTest): # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) + def test_reconnect(self): + """test that an 'is_disconnect' condition will invalidate the connection, and additionally + dispose the previous connection pool and recreate.""" + + pid = id(db.pool) # make a connection @@ -68,17 +75,17 @@ class ReconnectTest(PersistTest): assert len(dbapi.connections) == 2 # set it to fail - conn.connection.connection.explode = True - + dbapi.shutdown() + try: - # execute should fail conn.execute("SELECT 1") assert False - except exceptions.SQLAlchemyError, e: + except exceptions.DBAPIError: pass # assert was invalidated - assert conn.connection.connection is None + assert not conn.closed + assert conn.invalidated # close shouldnt break conn.close() @@ -92,6 +99,182 @@ class ReconnectTest(PersistTest): conn.execute("SELECT 1") conn.close() assert len(dbapi.connections) == 1 + + def test_invalidate_trans(self): + conn = db.connect() + trans = conn.begin() + dbapi.shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + # assert was invalidated + assert len(dbapi.connections) == 0 + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + conn.execute("SELECT 1") + assert not conn.invalidated + + assert len(dbapi.connections) == 1 + + def test_conn_reusable(self): + conn = db.connect() + + conn.execute("SELECT 1") + + assert len(dbapi.connections) == 1 + + dbapi.shutdown() + + # raises error + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError: + pass + + assert not conn.closed + assert conn.invalidated + + # ensure all connections closed (pool was recycled) + assert len(dbapi.connections) == 0 + + # test reconnects + conn.execute("SELECT 1") + assert not conn.invalidated + assert len(dbapi.connections) == 1 + + +class RealReconnectTest(PersistTest): + def setUp(self): + global engine + engine = engines.reconnecting_engine() + + def tearDown(self): + engine.dispose() + + def test_reconnect(self): + conn = engine.connect() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + # one more time + engine.test_shutdown() + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + + conn.close() + + def test_close(self): + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + conn.close() + conn = engine.connect() + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + + def test_with_transaction(self): + conn = engine.connect() + + trans = conn.begin() + + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute("SELECT 1") + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute("SELECT 1") + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + assert conn.invalidated + self.assertEquals(conn.execute("SELECT 1").scalar(), 1) + assert not conn.invalidated + if __name__ == '__main__': testbase.main() diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 69696ec64..59acaff2f 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -342,7 +342,7 @@ class UnicodeTest(AssertMixin): assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'" finally: unicode_engine.dispose() - + @testing.unsupported('oracle') def testblanks(self): unicode_table.insert().execute(unicode_varchar=u'') diff --git a/test/testlib/engines.py b/test/testlib/engines.py index 56507618c..b576a1536 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -68,7 +68,31 @@ def close_open_connections(fn): decorated.__name__ = fn.__name__ return decorated - +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def shutdown(self): + for c in list(self.connections): + c.close() + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + engine = testing_engine(url, {'module':ReconnectFixture(dbapi)}) + engine.test_shutdown = engine.dialect.dbapi.shutdown + return engine + def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" @@ -109,3 +133,5 @@ def utf8_engine(url=None, options=None): url = str(url) return testing_engine(url, options) + + |