diff options
author | Kevin Benton <kevin@benton.pub> | 2017-05-18 11:28:09 -0700 |
---|---|---|
committer | Kevin Benton <kevin@benton.pub> | 2017-05-18 11:32:11 -0700 |
commit | 7c0106b579e4aa6c19340b8260331485346dbf20 (patch) | |
tree | 68011d23293cdd40c5acd4615092f340d341348a | |
parent | 311e8b8441b22ba6235bd69360adcd8676e98017 (diff) | |
download | oslo-db-7c0106b579e4aa6c19340b8260331485346dbf20.tar.gz |
Attach context being used to session/connection info
Attaching a reference to the context to the session/connection
is useful for SQLAlchemy event handlers that only get access to
the session and need access to the context for additional details.
The use case driving this is the related bug where we want to pass
conditional update constraints set in the context via the HTTP API
down into an SQLAlchemy event listener.
Related-Bug: #1493714
Change-Id: I5c08e672782344e8778187e257e03c1a1d1b019a
-rw-r--r-- | oslo_db/sqlalchemy/enginefacade.py | 45 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_enginefacade.py | 21 |
2 files changed, 50 insertions, 16 deletions
diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py index 88bca6b..0846dbb 100644 --- a/oslo_db/sqlalchemy/enginefacade.py +++ b/oslo_db/sqlalchemy/enginefacade.py @@ -577,24 +577,27 @@ class _TransactionContext(object): self.flush_on_subtransaction = kw['flush_on_subtransaction'] @contextlib.contextmanager - def _connection(self, savepoint=False): + def _connection(self, savepoint=False, context=None): if self.connection is None: try: if self.session is not None: # use existing session, which is outer to us self.connection = self.session.connection() if savepoint: - with self.connection.begin_nested(): + with self.connection.begin_nested(), \ + self._add_context(self.connection, context): yield self.connection else: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection else: # is outermost self.connection = self.factory._create_connection( mode=self.mode) self.transaction = self.connection.begin() try: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection self._end_connection_transaction(self.transaction) except Exception: self.transaction.rollback() @@ -612,19 +615,22 @@ class _TransactionContext(object): else: # use existing connection, which is outer to us if savepoint: - with self.connection.begin_nested(): + with self.connection.begin_nested(), \ + self._add_context(self.connection, context): yield self.connection else: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection @contextlib.contextmanager - def _session(self, savepoint=False): + def _session(self, savepoint=False, context=None): if self.session is None: self.session = self.factory._create_session( bind=self.connection, mode=self.mode) try: self.session.begin() - yield self.session + with self._add_context(self.session, context): + yield self.session self._end_session_transaction(self.session) except Exception: self.session.rollback() @@ -640,12 +646,21 @@ class _TransactionContext(object): # use existing session, which is outer to us if savepoint: with self.session.begin_nested(): - yield self.session + with self._add_context(self.session, context): + yield self.session else: - yield self.session + with self._add_context(self.session, context): + yield self.session if self.flush_on_subtransaction: self.session.flush() + @contextlib.contextmanager + def _add_context(self, connection, context): + restore_context = connection.info.get('using_context') + connection.info['using_context'] = context + yield connection + connection.info['using_context'] = restore_context + def _end_session_transaction(self, session): if self.mode is _WRITER: session.commit() @@ -663,7 +678,8 @@ class _TransactionContext(object): else: transaction.rollback() - def _produce_block(self, mode, connection, savepoint, allow_async=False): + def _produce_block(self, mode, connection, savepoint, allow_async=False, + context=None): if mode is _WRITER: self._writer() elif mode is _ASYNC_READER: @@ -671,9 +687,9 @@ class _TransactionContext(object): else: self._reader(allow_async) if connection: - return self._connection(savepoint) + return self._connection(savepoint, context=context) else: - return self._session(savepoint) + return self._session(savepoint, context=context) def _writer(self): if self.mode is None: @@ -1008,7 +1024,8 @@ class _TransactionContextManager(object): mode=self._mode, connection=self._connection, savepoint=self._savepoint, - allow_async=self._allow_async) as resource: + allow_async=self._allow_async, + context=context) as resource: yield resource else: yield diff --git a/oslo_db/tests/sqlalchemy/test_enginefacade.py b/oslo_db/tests/sqlalchemy/test_enginefacade.py index 566dd36..206c335 100644 --- a/oslo_db/tests/sqlalchemy/test_enginefacade.py +++ b/oslo_db/tests/sqlalchemy/test_enginefacade.py @@ -55,6 +55,7 @@ class SingletonConnection(SingletonOnName): def __init__(self, **kw): super(SingletonConnection, self).__init__( "connection", **kw) + self.info = {} class SingletonEngine(SingletonOnName): @@ -113,14 +114,16 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): writer_conn = SingletonConnection() writer_engine = SingletonEngine(writer_conn) writer_session = mock.Mock( - connection=mock.Mock(return_value=writer_conn)) + connection=mock.Mock(return_value=writer_conn), + info={}) writer_maker = mock.Mock(return_value=writer_session) if self.slave_uri: async_reader_conn = SingletonConnection() async_reader_engine = SingletonEngine(async_reader_conn) async_reader_session = mock.Mock( - connection=mock.Mock(return_value=async_reader_conn)) + connection=mock.Mock(return_value=async_reader_conn), + info={}) async_reader_maker = mock.Mock(return_value=async_reader_session) else: @@ -770,6 +773,20 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): with self._assert_reader_session(makers) as session: session.execute("test1") + def test_using_context_present_in_session_info(self): + context = oslo_context.RequestContext() + + with enginefacade.reader.using(context) as session: + self.assertEqual(context, session.info['using_context']) + self.assertIsNone(session.info['using_context']) + + def test_using_context_present_in_connection_info(self): + context = oslo_context.RequestContext() + + with enginefacade.writer.connection.using(context) as connection: + self.assertEqual(context, connection.info['using_context']) + self.assertIsNone(connection.info['using_context']) + def test_using_reader_rollback_reader_session(self): enginefacade.configure(rollback_reader_sessions=True) |