summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oslo_context/context.py7
-rw-r--r--oslo_context/tests/test_context.py14
2 files changed, 19 insertions, 2 deletions
diff --git a/oslo_context/context.py b/oslo_context/context.py
index bd1a6c4..adad06d 100644
--- a/oslo_context/context.py
+++ b/oslo_context/context.py
@@ -53,7 +53,7 @@ class RequestContext(object):
def __init__(self, auth_token=None, user=None, tenant=None, domain=None,
user_domain=None, project_domain=None, is_admin=False,
read_only=False, show_deleted=False, request_id=None,
- resource_uuid=None, overwrite=True):
+ resource_uuid=None, overwrite=True, roles=None):
"""Initialize the RequestContext
:param overwrite: Set to False to ensure that the greenthread local
@@ -69,6 +69,7 @@ class RequestContext(object):
self.read_only = read_only
self.show_deleted = show_deleted
self.resource_uuid = resource_uuid
+ self.roles = roles or []
if not request_id:
request_id = generate_request_id()
self.request_id = request_id
@@ -99,6 +100,7 @@ class RequestContext(object):
'auth_token': self.auth_token,
'request_id': self.request_id,
'resource_uuid': self.resource_uuid,
+ 'roles': self.roles,
'user_identity': user_idt}
def get_logging_values(self):
@@ -143,6 +145,9 @@ class RequestContext(object):
kwargs.setdefault('project_domain',
environ.get('HTTP_X_PROJECT_DOMAIN_ID'))
+ roles = environ.get('HTTP_X_ROLES')
+ kwargs.setdefault('roles', roles.split(',') if roles else [])
+
return cls(**kwargs)
diff --git a/oslo_context/tests/test_context.py b/oslo_context/tests/test_context.py
index 4555da3..eff31c9 100644
--- a/oslo_context/tests/test_context.py
+++ b/oslo_context/tests/test_context.py
@@ -135,12 +135,14 @@ class ContextTest(test_base.BaseTestCase):
project_id = uuid.uuid4().hex
user_domain_id = uuid.uuid4().hex
project_domain_id = uuid.uuid4().hex
+ roles = [uuid.uuid4().hex, uuid.uuid4().hex, uuid.uuid4().hex]
environ = {'HTTP_X_AUTH_TOKEN': auth_token,
'HTTP_X_USER_ID': user_id,
'HTTP_X_PROJECT_ID': project_id,
'HTTP_X_USER_DOMAIN_ID': user_domain_id,
- 'HTTP_X_PROJECT_DOMAIN_ID': project_domain_id}
+ 'HTTP_X_PROJECT_DOMAIN_ID': project_domain_id,
+ 'HTTP_X_ROLES': ','.join(roles)}
ctx = context.RequestContext.from_environ(environ)
@@ -149,6 +151,14 @@ class ContextTest(test_base.BaseTestCase):
self.assertEqual(project_id, ctx.tenant)
self.assertEqual(user_domain_id, ctx.user_domain)
self.assertEqual(project_domain_id, ctx.project_domain)
+ self.assertEqual(roles, ctx.roles)
+
+ def test_from_environ_no_roles(self):
+ ctx = context.RequestContext.from_environ(environ={})
+ self.assertEqual([], ctx.roles)
+
+ ctx = context.RequestContext.from_environ(environ={'HTTP_X_ROLES': ''})
+ self.assertEqual([], ctx.roles)
def test_from_function_and_args(self):
ctx = context.RequestContext(user="user1")
@@ -214,6 +224,7 @@ class ContextTest(test_base.BaseTestCase):
self.assertIn('request_id', d)
self.assertIn('resource_uuid', d)
self.assertIn('user_identity', d)
+ self.assertIn('roles', d)
self.assertEqual(auth_token, d['auth_token'])
self.assertEqual(tenant, d['tenant'])
@@ -228,6 +239,7 @@ class ContextTest(test_base.BaseTestCase):
user_identity = "%s %s %s %s %s" % (user, tenant, domain,
user_domain, project_domain)
self.assertEqual(user_identity, d['user_identity'])
+ self.assertEqual([], d['roles'])
def test_get_logging_values(self):
auth_token = "token1"