summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Benton <kevin@benton.pub>2017-05-18 11:28:09 -0700
committerKevin Benton <kevin@benton.pub>2017-05-18 11:32:11 -0700
commit7c0106b579e4aa6c19340b8260331485346dbf20 (patch)
tree68011d23293cdd40c5acd4615092f340d341348a
parent311e8b8441b22ba6235bd69360adcd8676e98017 (diff)
downloadoslo-db-7c0106b579e4aa6c19340b8260331485346dbf20.tar.gz
Attach context being used to session/connection info
Attaching a reference to the context to the session/connection is useful for SQLAlchemy event handlers that only get access to the session and need access to the context for additional details. The use case driving this is the related bug where we want to pass conditional update constraints set in the context via the HTTP API down into an SQLAlchemy event listener. Related-Bug: #1493714 Change-Id: I5c08e672782344e8778187e257e03c1a1d1b019a
-rw-r--r--oslo_db/sqlalchemy/enginefacade.py45
-rw-r--r--oslo_db/tests/sqlalchemy/test_enginefacade.py21
2 files changed, 50 insertions, 16 deletions
diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py
index 88bca6b..0846dbb 100644
--- a/oslo_db/sqlalchemy/enginefacade.py
+++ b/oslo_db/sqlalchemy/enginefacade.py
@@ -577,24 +577,27 @@ class _TransactionContext(object):
self.flush_on_subtransaction = kw['flush_on_subtransaction']
@contextlib.contextmanager
- def _connection(self, savepoint=False):
+ def _connection(self, savepoint=False, context=None):
if self.connection is None:
try:
if self.session is not None:
# use existing session, which is outer to us
self.connection = self.session.connection()
if savepoint:
- with self.connection.begin_nested():
+ with self.connection.begin_nested(), \
+ self._add_context(self.connection, context):
yield self.connection
else:
- yield self.connection
+ with self._add_context(self.connection, context):
+ yield self.connection
else:
# is outermost
self.connection = self.factory._create_connection(
mode=self.mode)
self.transaction = self.connection.begin()
try:
- yield self.connection
+ with self._add_context(self.connection, context):
+ yield self.connection
self._end_connection_transaction(self.transaction)
except Exception:
self.transaction.rollback()
@@ -612,19 +615,22 @@ class _TransactionContext(object):
else:
# use existing connection, which is outer to us
if savepoint:
- with self.connection.begin_nested():
+ with self.connection.begin_nested(), \
+ self._add_context(self.connection, context):
yield self.connection
else:
- yield self.connection
+ with self._add_context(self.connection, context):
+ yield self.connection
@contextlib.contextmanager
- def _session(self, savepoint=False):
+ def _session(self, savepoint=False, context=None):
if self.session is None:
self.session = self.factory._create_session(
bind=self.connection, mode=self.mode)
try:
self.session.begin()
- yield self.session
+ with self._add_context(self.session, context):
+ yield self.session
self._end_session_transaction(self.session)
except Exception:
self.session.rollback()
@@ -640,12 +646,21 @@ class _TransactionContext(object):
# use existing session, which is outer to us
if savepoint:
with self.session.begin_nested():
- yield self.session
+ with self._add_context(self.session, context):
+ yield self.session
else:
- yield self.session
+ with self._add_context(self.session, context):
+ yield self.session
if self.flush_on_subtransaction:
self.session.flush()
+ @contextlib.contextmanager
+ def _add_context(self, connection, context):
+ restore_context = connection.info.get('using_context')
+ connection.info['using_context'] = context
+ yield connection
+ connection.info['using_context'] = restore_context
+
def _end_session_transaction(self, session):
if self.mode is _WRITER:
session.commit()
@@ -663,7 +678,8 @@ class _TransactionContext(object):
else:
transaction.rollback()
- def _produce_block(self, mode, connection, savepoint, allow_async=False):
+ def _produce_block(self, mode, connection, savepoint, allow_async=False,
+ context=None):
if mode is _WRITER:
self._writer()
elif mode is _ASYNC_READER:
@@ -671,9 +687,9 @@ class _TransactionContext(object):
else:
self._reader(allow_async)
if connection:
- return self._connection(savepoint)
+ return self._connection(savepoint, context=context)
else:
- return self._session(savepoint)
+ return self._session(savepoint, context=context)
def _writer(self):
if self.mode is None:
@@ -1008,7 +1024,8 @@ class _TransactionContextManager(object):
mode=self._mode,
connection=self._connection,
savepoint=self._savepoint,
- allow_async=self._allow_async) as resource:
+ allow_async=self._allow_async,
+ context=context) as resource:
yield resource
else:
yield
diff --git a/oslo_db/tests/sqlalchemy/test_enginefacade.py b/oslo_db/tests/sqlalchemy/test_enginefacade.py
index 566dd36..206c335 100644
--- a/oslo_db/tests/sqlalchemy/test_enginefacade.py
+++ b/oslo_db/tests/sqlalchemy/test_enginefacade.py
@@ -55,6 +55,7 @@ class SingletonConnection(SingletonOnName):
def __init__(self, **kw):
super(SingletonConnection, self).__init__(
"connection", **kw)
+ self.info = {}
class SingletonEngine(SingletonOnName):
@@ -113,14 +114,16 @@ class MockFacadeTest(oslo_test_base.BaseTestCase):
writer_conn = SingletonConnection()
writer_engine = SingletonEngine(writer_conn)
writer_session = mock.Mock(
- connection=mock.Mock(return_value=writer_conn))
+ connection=mock.Mock(return_value=writer_conn),
+ info={})
writer_maker = mock.Mock(return_value=writer_session)
if self.slave_uri:
async_reader_conn = SingletonConnection()
async_reader_engine = SingletonEngine(async_reader_conn)
async_reader_session = mock.Mock(
- connection=mock.Mock(return_value=async_reader_conn))
+ connection=mock.Mock(return_value=async_reader_conn),
+ info={})
async_reader_maker = mock.Mock(return_value=async_reader_session)
else:
@@ -770,6 +773,20 @@ class MockFacadeTest(oslo_test_base.BaseTestCase):
with self._assert_reader_session(makers) as session:
session.execute("test1")
+ def test_using_context_present_in_session_info(self):
+ context = oslo_context.RequestContext()
+
+ with enginefacade.reader.using(context) as session:
+ self.assertEqual(context, session.info['using_context'])
+ self.assertIsNone(session.info['using_context'])
+
+ def test_using_context_present_in_connection_info(self):
+ context = oslo_context.RequestContext()
+
+ with enginefacade.writer.connection.using(context) as connection:
+ self.assertEqual(context, connection.info['using_context'])
+ self.assertIsNone(connection.info['using_context'])
+
def test_using_reader_rollback_reader_session(self):
enginefacade.configure(rollback_reader_sessions=True)