diff options
-rw-r--r-- | oslo/db/sqlalchemy/compat/__init__.py | 7 | ||||
-rw-r--r-- | oslo/db/sqlalchemy/exc_filters.py | 43 | ||||
-rw-r--r-- | oslo/db/sqlalchemy/session.py | 58 | ||||
-rw-r--r-- | tests/sqlalchemy/test_exc_filters.py | 88 |
4 files changed, 152 insertions, 44 deletions
diff --git a/oslo/db/sqlalchemy/compat/__init__.py b/oslo/db/sqlalchemy/compat/__init__.py index 2ba9954..1510dd0 100644 --- a/oslo/db/sqlalchemy/compat/__init__.py +++ b/oslo/db/sqlalchemy/compat/__init__.py @@ -16,10 +16,11 @@ added at some point but for which oslo.db provides a compatible versions for previous SQLAlchemy versions. """ -from oslo.db.sqlalchemy.compat import handle_error +from oslo.db.sqlalchemy.compat import handle_error as _h_err # trying to get: "from oslo.db.sqlalchemy import compat; compat.handle_error" # flake8 won't let me import handle_error directly -handle_error = handle_error.handle_error +handle_error = _h_err.handle_error +ExceptionContextImpl = _h_err.ExceptionContextImpl -__all__ = ['handle_error'] +__all__ = ['handle_error', 'ExceptionContextImpl'] diff --git a/oslo/db/sqlalchemy/exc_filters.py b/oslo/db/sqlalchemy/exc_filters.py index 3ec19cc..ab5eee6 100644 --- a/oslo/db/sqlalchemy/exc_filters.py +++ b/oslo/db/sqlalchemy/exc_filters.py @@ -202,6 +202,15 @@ def _raise_operational_errors_directly_filter(operational_error, raise operational_error +# For the db2, the error code is -30081 since the db2 is still not ready +@filters("mysql", sqla_exc.OperationalError, r".*\((?:2002|2003|2006|2013)") +@filters("ibm_db_sa", sqla_exc.OperationalError, r".*(?:-30081)") +def _is_db_connection_error(operational_error, match, engine_name, + is_disconnect): + """Detect the exception as indicating a recoverable error on connect.""" + raise exception.DBConnectionError(operational_error) + + @filters("*", sqla_exc.DBAPIError, r".*") def _raise_for_remaining_DBAPIError(error, match, engine_name, is_disconnect): """Filter for remaining DBAPIErrors and wrap if they represent @@ -253,12 +262,36 @@ def handler(context): for fn, regexp in regexp_reg: match = regexp.match(exc.message) if match: - fn( - exc, - match, - context.connection.dialect.name, - context.is_disconnect) + try: + fn( + exc, + match, + context.connection.dialect.name, + context.is_disconnect) + except exception.DBConnectionError: + context.is_disconnect = True + raise def register_engine(engine): compat.handle_error(engine, handler) + + +def handle_connect_error(engine): + """Provide a special context that will allow on-connect errors + to be raised within the filtering context. + + """ + try: + return engine.connect() + except Exception as e: + if isinstance(e, sqla_exc.StatementError): + s_exc, orig = e, e.orig + else: + s_exc, orig = None, e + + ctx = compat.ExceptionContextImpl( + orig, s_exc, engine, None, + None, None, None, False + ) + handler(ctx) diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py index 2b5a0e7..07a062e 100644 --- a/oslo/db/sqlalchemy/session.py +++ b/oslo/db/sqlalchemy/session.py @@ -279,12 +279,12 @@ Efficient use of soft deletes: """ import functools +import itertools import logging import re import time import six -from sqlalchemy import exc as sqla_exc from sqlalchemy.interfaces import PoolListener import sqlalchemy.orm from sqlalchemy.pool import NullPool, StaticPool @@ -415,18 +415,6 @@ def _mysql_set_mode_callback(engine, sql_mode): _mysql_check_effective_sql_mode(engine) -def _is_db_connection_error(args): - """Return True if error in connecting to db.""" - # NOTE(adam_g): This is currently MySQL specific and needs to be extended - # to support Postgres and others. - # For the db2, the error code is -30081 since the db2 is still not ready - conn_err_codes = ('2002', '2003', '2006', '2013', '-30081') - for err_code in conn_err_codes: - if args.find(err_code) != -1: - return True - return False - - def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, idle_timeout=3600, connection_debug=0, max_pool_size=None, max_overflow=None, @@ -485,38 +473,36 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb': _patch_mysqldb_with_stacktrace_comments() - try: - engine.connect() - except sqla_exc.OperationalError as e: - if not _is_db_connection_error(e.args[0]): - raise - - remaining = max_retries - if remaining == -1: - remaining = 'infinite' - while True: - msg = _LW('SQL connection failed. %s attempts left.') - LOG.warning(msg, remaining) - if remaining != 'infinite': - remaining -= 1 - time.sleep(retry_interval) - try: - engine.connect() - break - except sqla_exc.OperationalError as e: - if (remaining != 'infinite' and remaining == 0) or \ - not _is_db_connection_error(e.args[0]): - raise - # register alternate exception handler exc_filters.register_engine(engine) # register on begin handler sqlalchemy.event.listen(engine, "begin", _begin_ping_listener) + # initial connect + test + _test_connection(engine, max_retries, retry_interval) + return engine +def _test_connection(engine, max_retries, retry_interval): + if max_retries == -1: + attempts = itertools.count() + else: + attempts = six.moves.range(max_retries) + de = None + for attempt in attempts: + try: + return exc_filters.handle_connect_error(engine) + except exception.DBConnectionError as de: + msg = _LW('SQL connection failed. %s attempts left.') + LOG.warning(msg, max_retries - attempt) + time.sleep(retry_interval) + else: + if de is not None: + six.reraise(type(de), de) + + class Query(sqlalchemy.orm.query.Query): """Subclass of sqlalchemy.query with soft_delete() method.""" def soft_delete(self, synchronize_session='evaluate'): diff --git a/tests/sqlalchemy/test_exc_filters.py b/tests/sqlalchemy/test_exc_filters.py index b4f25b4..8fa4486 100644 --- a/tests/sqlalchemy/test_exc_filters.py +++ b/tests/sqlalchemy/test_exc_filters.py @@ -21,6 +21,7 @@ import sqlalchemy as sqla from sqlalchemy.orm import mapper from oslo.db import exception +from oslo.db.sqlalchemy import session from oslo.db.sqlalchemy import test_base _TABLE_NAME = '__tmp__test__tmp__' @@ -509,3 +510,90 @@ class TestDBDisconnected(TestsExceptionFilter): self.OperationalError( 'SQL30081N: DB2 Server connection is no longer active') ) + + +class TestDBConnectRetry(TestsExceptionFilter): + + def _run_test(self, dialect_name, exception, count, retries): + counter = itertools.count() + + engine = self.engine + + # empty out the connection pool + engine.dispose() + + connect_fn = engine.dialect.connect + + def cant_connect(*arg, **kw): + if next(counter) < count: + raise exception + else: + return connect_fn(*arg, **kw) + + with self._dbapi_fixture(dialect_name): + with mock.patch.object(engine.dialect, "connect", cant_connect): + return session._test_connection(engine, retries, .01) + + def test_connect_no_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 0 + ) + # didnt connect because nothing was tried + self.assertIsNone(conn) + + def test_connect_inifinite_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_past_failure(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 3 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_not_candidate_exception(self): + self.assertRaises( + sqla.exc.OperationalError, # remember, we pass OperationalErrors + # through at the moment :) + self._run_test, + "mysql", + self.OperationalError("Error: (2015) I can't connect period"), + 2, 3 + ) + + def test_connect_retry_stops_infailure(self): + self.assertRaises( + exception.DBConnectionError, + self._run_test, + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 3, 2 + ) + + def test_db2_error_positive(self): + conn = self._run_test( + "ibm_db_sa", + self.OperationalError("blah blah -30081 blah blah"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_db2_error_negative(self): + self.assertRaises( + sqla.exc.OperationalError, + self._run_test, + "ibm_db_sa", + self.OperationalError("blah blah -39981 blah blah"), + 2, 3 + ) |