summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oslo/db/sqlalchemy/compat/__init__.py7
-rw-r--r--oslo/db/sqlalchemy/exc_filters.py43
-rw-r--r--oslo/db/sqlalchemy/session.py58
-rw-r--r--tests/sqlalchemy/test_exc_filters.py88
4 files changed, 152 insertions, 44 deletions
diff --git a/oslo/db/sqlalchemy/compat/__init__.py b/oslo/db/sqlalchemy/compat/__init__.py
index 2ba9954..1510dd0 100644
--- a/oslo/db/sqlalchemy/compat/__init__.py
+++ b/oslo/db/sqlalchemy/compat/__init__.py
@@ -16,10 +16,11 @@ added at some point but for which oslo.db provides a compatible versions
for previous SQLAlchemy versions.
"""
-from oslo.db.sqlalchemy.compat import handle_error
+from oslo.db.sqlalchemy.compat import handle_error as _h_err
# trying to get: "from oslo.db.sqlalchemy import compat; compat.handle_error"
# flake8 won't let me import handle_error directly
-handle_error = handle_error.handle_error
+handle_error = _h_err.handle_error
+ExceptionContextImpl = _h_err.ExceptionContextImpl
-__all__ = ['handle_error']
+__all__ = ['handle_error', 'ExceptionContextImpl']
diff --git a/oslo/db/sqlalchemy/exc_filters.py b/oslo/db/sqlalchemy/exc_filters.py
index 3ec19cc..ab5eee6 100644
--- a/oslo/db/sqlalchemy/exc_filters.py
+++ b/oslo/db/sqlalchemy/exc_filters.py
@@ -202,6 +202,15 @@ def _raise_operational_errors_directly_filter(operational_error,
raise operational_error
+# For the db2, the error code is -30081 since the db2 is still not ready
+@filters("mysql", sqla_exc.OperationalError, r".*\((?:2002|2003|2006|2013)")
+@filters("ibm_db_sa", sqla_exc.OperationalError, r".*(?:-30081)")
+def _is_db_connection_error(operational_error, match, engine_name,
+ is_disconnect):
+ """Detect the exception as indicating a recoverable error on connect."""
+ raise exception.DBConnectionError(operational_error)
+
+
@filters("*", sqla_exc.DBAPIError, r".*")
def _raise_for_remaining_DBAPIError(error, match, engine_name, is_disconnect):
"""Filter for remaining DBAPIErrors and wrap if they represent
@@ -253,12 +262,36 @@ def handler(context):
for fn, regexp in regexp_reg:
match = regexp.match(exc.message)
if match:
- fn(
- exc,
- match,
- context.connection.dialect.name,
- context.is_disconnect)
+ try:
+ fn(
+ exc,
+ match,
+ context.connection.dialect.name,
+ context.is_disconnect)
+ except exception.DBConnectionError:
+ context.is_disconnect = True
+ raise
def register_engine(engine):
compat.handle_error(engine, handler)
+
+
+def handle_connect_error(engine):
+ """Provide a special context that will allow on-connect errors
+ to be raised within the filtering context.
+
+ """
+ try:
+ return engine.connect()
+ except Exception as e:
+ if isinstance(e, sqla_exc.StatementError):
+ s_exc, orig = e, e.orig
+ else:
+ s_exc, orig = None, e
+
+ ctx = compat.ExceptionContextImpl(
+ orig, s_exc, engine, None,
+ None, None, None, False
+ )
+ handler(ctx)
diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py
index 2b5a0e7..07a062e 100644
--- a/oslo/db/sqlalchemy/session.py
+++ b/oslo/db/sqlalchemy/session.py
@@ -279,12 +279,12 @@ Efficient use of soft deletes:
"""
import functools
+import itertools
import logging
import re
import time
import six
-from sqlalchemy import exc as sqla_exc
from sqlalchemy.interfaces import PoolListener
import sqlalchemy.orm
from sqlalchemy.pool import NullPool, StaticPool
@@ -415,18 +415,6 @@ def _mysql_set_mode_callback(engine, sql_mode):
_mysql_check_effective_sql_mode(engine)
-def _is_db_connection_error(args):
- """Return True if error in connecting to db."""
- # NOTE(adam_g): This is currently MySQL specific and needs to be extended
- # to support Postgres and others.
- # For the db2, the error code is -30081 since the db2 is still not ready
- conn_err_codes = ('2002', '2003', '2006', '2013', '-30081')
- for err_code in conn_err_codes:
- if args.find(err_code) != -1:
- return True
- return False
-
-
def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
idle_timeout=3600,
connection_debug=0, max_pool_size=None, max_overflow=None,
@@ -485,38 +473,36 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb':
_patch_mysqldb_with_stacktrace_comments()
- try:
- engine.connect()
- except sqla_exc.OperationalError as e:
- if not _is_db_connection_error(e.args[0]):
- raise
-
- remaining = max_retries
- if remaining == -1:
- remaining = 'infinite'
- while True:
- msg = _LW('SQL connection failed. %s attempts left.')
- LOG.warning(msg, remaining)
- if remaining != 'infinite':
- remaining -= 1
- time.sleep(retry_interval)
- try:
- engine.connect()
- break
- except sqla_exc.OperationalError as e:
- if (remaining != 'infinite' and remaining == 0) or \
- not _is_db_connection_error(e.args[0]):
- raise
-
# register alternate exception handler
exc_filters.register_engine(engine)
# register on begin handler
sqlalchemy.event.listen(engine, "begin", _begin_ping_listener)
+ # initial connect + test
+ _test_connection(engine, max_retries, retry_interval)
+
return engine
+def _test_connection(engine, max_retries, retry_interval):
+ if max_retries == -1:
+ attempts = itertools.count()
+ else:
+ attempts = six.moves.range(max_retries)
+ de = None
+ for attempt in attempts:
+ try:
+ return exc_filters.handle_connect_error(engine)
+ except exception.DBConnectionError as de:
+ msg = _LW('SQL connection failed. %s attempts left.')
+ LOG.warning(msg, max_retries - attempt)
+ time.sleep(retry_interval)
+ else:
+ if de is not None:
+ six.reraise(type(de), de)
+
+
class Query(sqlalchemy.orm.query.Query):
"""Subclass of sqlalchemy.query with soft_delete() method."""
def soft_delete(self, synchronize_session='evaluate'):
diff --git a/tests/sqlalchemy/test_exc_filters.py b/tests/sqlalchemy/test_exc_filters.py
index b4f25b4..8fa4486 100644
--- a/tests/sqlalchemy/test_exc_filters.py
+++ b/tests/sqlalchemy/test_exc_filters.py
@@ -21,6 +21,7 @@ import sqlalchemy as sqla
from sqlalchemy.orm import mapper
from oslo.db import exception
+from oslo.db.sqlalchemy import session
from oslo.db.sqlalchemy import test_base
_TABLE_NAME = '__tmp__test__tmp__'
@@ -509,3 +510,90 @@ class TestDBDisconnected(TestsExceptionFilter):
self.OperationalError(
'SQL30081N: DB2 Server connection is no longer active')
)
+
+
+class TestDBConnectRetry(TestsExceptionFilter):
+
+ def _run_test(self, dialect_name, exception, count, retries):
+ counter = itertools.count()
+
+ engine = self.engine
+
+ # empty out the connection pool
+ engine.dispose()
+
+ connect_fn = engine.dialect.connect
+
+ def cant_connect(*arg, **kw):
+ if next(counter) < count:
+ raise exception
+ else:
+ return connect_fn(*arg, **kw)
+
+ with self._dbapi_fixture(dialect_name):
+ with mock.patch.object(engine.dialect, "connect", cant_connect):
+ return session._test_connection(engine, retries, .01)
+
+ def test_connect_no_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 0
+ )
+ # didnt connect because nothing was tried
+ self.assertIsNone(conn)
+
+ def test_connect_inifinite_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, -1
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_connect_retry_past_failure(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 3
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_connect_retry_not_candidate_exception(self):
+ self.assertRaises(
+ sqla.exc.OperationalError, # remember, we pass OperationalErrors
+ # through at the moment :)
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2015) I can't connect period"),
+ 2, 3
+ )
+
+ def test_connect_retry_stops_infailure(self):
+ self.assertRaises(
+ exception.DBConnectionError,
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 3, 2
+ )
+
+ def test_db2_error_positive(self):
+ conn = self._run_test(
+ "ibm_db_sa",
+ self.OperationalError("blah blah -30081 blah blah"),
+ 2, -1
+ )
+ # conn is good
+ self.assertEqual(conn.scalar(sqla.select([1])), 1)
+
+ def test_db2_error_negative(self):
+ self.assertRaises(
+ sqla.exc.OperationalError,
+ self._run_test,
+ "ibm_db_sa",
+ self.OperationalError("blah blah -39981 blah blah"),
+ 2, 3
+ )