diff options
Diffstat (limited to 'oslo_db')
-rw-r--r-- | oslo_db/sqlalchemy/enginefacade.py | 45 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_enginefacade.py | 21 |
2 files changed, 50 insertions, 16 deletions
diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py index 88bca6b..0846dbb 100644 --- a/oslo_db/sqlalchemy/enginefacade.py +++ b/oslo_db/sqlalchemy/enginefacade.py @@ -577,24 +577,27 @@ class _TransactionContext(object): self.flush_on_subtransaction = kw['flush_on_subtransaction'] @contextlib.contextmanager - def _connection(self, savepoint=False): + def _connection(self, savepoint=False, context=None): if self.connection is None: try: if self.session is not None: # use existing session, which is outer to us self.connection = self.session.connection() if savepoint: - with self.connection.begin_nested(): + with self.connection.begin_nested(), \ + self._add_context(self.connection, context): yield self.connection else: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection else: # is outermost self.connection = self.factory._create_connection( mode=self.mode) self.transaction = self.connection.begin() try: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection self._end_connection_transaction(self.transaction) except Exception: self.transaction.rollback() @@ -612,19 +615,22 @@ class _TransactionContext(object): else: # use existing connection, which is outer to us if savepoint: - with self.connection.begin_nested(): + with self.connection.begin_nested(), \ + self._add_context(self.connection, context): yield self.connection else: - yield self.connection + with self._add_context(self.connection, context): + yield self.connection @contextlib.contextmanager - def _session(self, savepoint=False): + def _session(self, savepoint=False, context=None): if self.session is None: self.session = self.factory._create_session( bind=self.connection, mode=self.mode) try: self.session.begin() - yield self.session + with self._add_context(self.session, context): + yield self.session self._end_session_transaction(self.session) except Exception: self.session.rollback() @@ -640,12 +646,21 @@ class _TransactionContext(object): # use existing session, which is outer to us if savepoint: with self.session.begin_nested(): - yield self.session + with self._add_context(self.session, context): + yield self.session else: - yield self.session + with self._add_context(self.session, context): + yield self.session if self.flush_on_subtransaction: self.session.flush() + @contextlib.contextmanager + def _add_context(self, connection, context): + restore_context = connection.info.get('using_context') + connection.info['using_context'] = context + yield connection + connection.info['using_context'] = restore_context + def _end_session_transaction(self, session): if self.mode is _WRITER: session.commit() @@ -663,7 +678,8 @@ class _TransactionContext(object): else: transaction.rollback() - def _produce_block(self, mode, connection, savepoint, allow_async=False): + def _produce_block(self, mode, connection, savepoint, allow_async=False, + context=None): if mode is _WRITER: self._writer() elif mode is _ASYNC_READER: @@ -671,9 +687,9 @@ class _TransactionContext(object): else: self._reader(allow_async) if connection: - return self._connection(savepoint) + return self._connection(savepoint, context=context) else: - return self._session(savepoint) + return self._session(savepoint, context=context) def _writer(self): if self.mode is None: @@ -1008,7 +1024,8 @@ class _TransactionContextManager(object): mode=self._mode, connection=self._connection, savepoint=self._savepoint, - allow_async=self._allow_async) as resource: + allow_async=self._allow_async, + context=context) as resource: yield resource else: yield diff --git a/oslo_db/tests/sqlalchemy/test_enginefacade.py b/oslo_db/tests/sqlalchemy/test_enginefacade.py index 566dd36..206c335 100644 --- a/oslo_db/tests/sqlalchemy/test_enginefacade.py +++ b/oslo_db/tests/sqlalchemy/test_enginefacade.py @@ -55,6 +55,7 @@ class SingletonConnection(SingletonOnName): def __init__(self, **kw): super(SingletonConnection, self).__init__( "connection", **kw) + self.info = {} class SingletonEngine(SingletonOnName): @@ -113,14 +114,16 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): writer_conn = SingletonConnection() writer_engine = SingletonEngine(writer_conn) writer_session = mock.Mock( - connection=mock.Mock(return_value=writer_conn)) + connection=mock.Mock(return_value=writer_conn), + info={}) writer_maker = mock.Mock(return_value=writer_session) if self.slave_uri: async_reader_conn = SingletonConnection() async_reader_engine = SingletonEngine(async_reader_conn) async_reader_session = mock.Mock( - connection=mock.Mock(return_value=async_reader_conn)) + connection=mock.Mock(return_value=async_reader_conn), + info={}) async_reader_maker = mock.Mock(return_value=async_reader_session) else: @@ -770,6 +773,20 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): with self._assert_reader_session(makers) as session: session.execute("test1") + def test_using_context_present_in_session_info(self): + context = oslo_context.RequestContext() + + with enginefacade.reader.using(context) as session: + self.assertEqual(context, session.info['using_context']) + self.assertIsNone(session.info['using_context']) + + def test_using_context_present_in_connection_info(self): + context = oslo_context.RequestContext() + + with enginefacade.writer.connection.using(context) as connection: + self.assertEqual(context, connection.info['using_context']) + self.assertIsNone(connection.info['using_context']) + def test_using_reader_rollback_reader_session(self): enginefacade.configure(rollback_reader_sessions=True) |