diff options
-rw-r--r-- | oslo/db/sqlalchemy/compat/__init__.py | 4 | ||||
-rw-r--r-- | oslo/db/sqlalchemy/compat/engine_connect.py | 60 | ||||
-rw-r--r-- | oslo/db/sqlalchemy/compat/utils.py | 1 | ||||
-rw-r--r-- | oslo/db/sqlalchemy/session.py | 21 | ||||
-rw-r--r-- | tests/sqlalchemy/test_engine_connect.py | 68 | ||||
-rw-r--r-- | tests/sqlalchemy/test_exc_filters.py | 14 | ||||
-rw-r--r-- | tests/sqlalchemy/test_sqlalchemy.py | 6 |
7 files changed, 160 insertions, 14 deletions
diff --git a/oslo/db/sqlalchemy/compat/__init__.py b/oslo/db/sqlalchemy/compat/__init__.py index 1510dd0..aaaf200 100644 --- a/oslo/db/sqlalchemy/compat/__init__.py +++ b/oslo/db/sqlalchemy/compat/__init__.py @@ -16,11 +16,13 @@ added at some point but for which oslo.db provides a compatible versions for previous SQLAlchemy versions. """ +from oslo.db.sqlalchemy.compat import engine_connect as _e_conn from oslo.db.sqlalchemy.compat import handle_error as _h_err # trying to get: "from oslo.db.sqlalchemy import compat; compat.handle_error" # flake8 won't let me import handle_error directly +engine_connect = _e_conn.engine_connect handle_error = _h_err.handle_error ExceptionContextImpl = _h_err.ExceptionContextImpl -__all__ = ['handle_error', 'ExceptionContextImpl'] +__all__ = ['engine_connect', 'handle_error', 'ExceptionContextImpl'] diff --git a/oslo/db/sqlalchemy/compat/engine_connect.py b/oslo/db/sqlalchemy/compat/engine_connect.py new file mode 100644 index 0000000..d64d462 --- /dev/null +++ b/oslo/db/sqlalchemy/compat/engine_connect.py @@ -0,0 +1,60 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Provide forwards compatibility for the engine_connect event. + +See the "engine_connect" event at +http://docs.sqlalchemy.org/en/rel_0_9/core/events.html. + + +""" + +from sqlalchemy.engine import Engine +from sqlalchemy import event + +from oslo.db.sqlalchemy.compat import utils + + +def engine_connect(engine, listener): + """Add an engine_connect listener for the given :class:`.Engine`. + + This listener uses the SQLAlchemy + :meth:`sqlalchemy.event.ConnectionEvents.engine_connect` + event for 0.9.0 and above, and implements an interim listener + for 0.8 versions. + + """ + if utils.sqla_090: + event.listen(engine, "engine_connect", listener) + return + + assert isinstance(engine, Engine), \ + "engine argument must be an Engine instance, not a Connection" + + if not getattr(engine._connection_cls, + '_oslo_engine_connect_wrapper', False): + engine._oslo_engine_connect_events = [] + + class Connection(engine._connection_cls): + _oslo_engine_connect_wrapper = True + + def __init__(self, *arg, **kw): + super(Connection, self).__init__(*arg, **kw) + + _oslo_engine_connect_events = getattr( + self.engine, + '_oslo_engine_connect_events', + False) + if _oslo_engine_connect_events: + for fn in _oslo_engine_connect_events: + fn(self, kw.get('_branch', False)) + engine._connection_cls = Connection + engine._oslo_engine_connect_events.append(listener) diff --git a/oslo/db/sqlalchemy/compat/utils.py b/oslo/db/sqlalchemy/compat/utils.py index a1c6d83..7ba83c3 100644 --- a/oslo/db/sqlalchemy/compat/utils.py +++ b/oslo/db/sqlalchemy/compat/utils.py @@ -21,4 +21,5 @@ _SQLA_VERSION = tuple( sqla_097 = _SQLA_VERSION >= (0, 9, 7) sqla_094 = _SQLA_VERSION >= (0, 9, 4) +sqla_090 = _SQLA_VERSION >= (0, 9, 0) sqla_08 = _SQLA_VERSION >= (0, 8) diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py index 192ad1d..882a9ea 100644 --- a/oslo/db/sqlalchemy/session.py +++ b/oslo/db/sqlalchemy/session.py @@ -293,6 +293,7 @@ from sqlalchemy.sql.expression import select from oslo.db._i18n import _LW from oslo.db import exception from oslo.db import options +from oslo.db.sqlalchemy import compat from oslo.db.sqlalchemy import exc_filters from oslo.db.sqlalchemy import utils @@ -311,12 +312,22 @@ def _thread_yield(dbapi_con, con_record): time.sleep(0) -def _begin_ping_listener(connection): - """Ping the server at transaction begin. +def _connect_ping_listener(connection, branch): + """Ping the server at connection startup. Ping the server at transaction begin and transparently reconnect if a disconnect exception occurs. """ + if branch: + return + + # turn off "close with result". This can also be accomplished + # by branching the connection, however just setting the flag is + # more performant and also doesn't get involved with some + # connection-invalidation awkardness that occurs (see + # https://bitbucket.org/zzzeek/sqlalchemy/issue/3215/) + save_should_close_with_result = connection.should_close_with_result + connection.should_close_with_result = False try: # run a SELECT 1. use a core select() so that # any details like that needed by Oracle, DB2 etc. are handled. @@ -329,6 +340,8 @@ def _begin_ping_listener(connection): # new connections assuming they are good now. # run the select again to re-validate the Connection. connection.scalar(select([1])) + finally: + connection.should_close_with_result = save_should_close_with_result def _setup_logging(connection_debug=0): @@ -389,8 +402,8 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, # register alternate exception handler exc_filters.register_engine(engine) - # register on begin handler - sqlalchemy.event.listen(engine, "begin", _begin_ping_listener) + # register engine connect handler + compat.engine_connect(engine, _connect_ping_listener) # initial connect + test _test_connection(engine, max_retries, retry_interval) diff --git a/tests/sqlalchemy/test_engine_connect.py b/tests/sqlalchemy/test_engine_connect.py new file mode 100644 index 0000000..80d17cb --- /dev/null +++ b/tests/sqlalchemy/test_engine_connect.py @@ -0,0 +1,68 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test the compatibility layer for the engine_connect() event. + +This event is added as of SQLAlchemy 0.9.0; oslo.db provides a compatibility +layer for prior SQLAlchemy versions. + +""" + +import mock +from oslotest import base as test_base +import sqlalchemy as sqla + +from oslo.db.sqlalchemy.compat import engine_connect + + +class EngineConnectTest(test_base.BaseTestCase): + + def setUp(self): + super(EngineConnectTest, self).setUp() + + self.engine = engine = sqla.create_engine("sqlite://") + self.addCleanup(engine.dispose) + + def test_connect_event(self): + engine = self.engine + + listener = mock.Mock() + engine_connect(engine, listener) + + conn = engine.connect() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False)] + ) + + conn.close() + + conn2 = engine.connect() + conn2.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(conn2, False)] + ) + + def test_branch(self): + engine = self.engine + + listener = mock.Mock() + engine_connect(engine, listener) + + conn = engine.connect() + branched = conn.connect() + conn.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(branched, True)] + ) diff --git a/tests/sqlalchemy/test_exc_filters.py b/tests/sqlalchemy/test_exc_filters.py index d0ea38b..09b5931 100644 --- a/tests/sqlalchemy/test_exc_filters.py +++ b/tests/sqlalchemy/test_exc_filters.py @@ -22,6 +22,7 @@ import sqlalchemy as sqla from sqlalchemy.orm import mapper from oslo.db import exception +from oslo.db.sqlalchemy import compat from oslo.db.sqlalchemy import exc_filters from oslo.db.sqlalchemy import session from oslo.db.sqlalchemy import test_base @@ -67,7 +68,6 @@ class TestsExceptionFilter(oslo_test_base.BaseTestCase): super(TestsExceptionFilter, self).setUp() self.engine = sqla.create_engine("sqlite://") exc_filters.register_engine(self.engine) - sqla.event.listen(self.engine, "begin", session._begin_ping_listener) self.engine.connect().close() # initialize @contextlib.contextmanager @@ -620,6 +620,8 @@ class TestDBDisconnected(TestsExceptionFilter): def _fixture(self, dialect_name, exception, num_disconnects): engine = self.engine + compat.engine_connect(engine, session._connect_ping_listener) + real_do_execute = engine.dialect.do_execute counter = itertools.count(1) @@ -650,14 +652,14 @@ class TestDBDisconnected(TestsExceptionFilter): self.assertTrue(conn.in_transaction()) with self._fixture(dialect_name, exc_obj, 2): - conn = self.engine.connect() self.assertRaises( exception.DBConnectionError, - conn.begin + self.engine.connect ) - self.assertFalse(conn.closed) - self.assertFalse(conn.in_transaction()) - self.assertTrue(conn.invalidated) + + # test implicit execution + with self._fixture(dialect_name, exc_obj, 1): + self.assertEqual(self.engine.scalar(sqla.select([1])), 1) def test_mysql_ping_listener_disconnected(self): for code in [2006, 2013, 2014, 2045, 2055]: diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 5099c44..5ba3207 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -614,11 +614,11 @@ class PatchStacktraceTest(test_base.DbTestCase): def test_trace(self): engine = self.engine - engine.connect() + session._add_trace_comments(engine) + conn = engine.connect() with mock.patch.object(engine.dialect, "do_execute") as mock_exec: - session._add_trace_comments(engine) - engine.execute("select * from table") + conn.execute("select * from table") call = mock_exec.mock_calls[0] |