diff options
-rw-r--r-- | keystoneclient/access.py | 52 | ||||
-rw-r--r-- | keystoneclient/tests/v2_0/test_access.py | 31 | ||||
-rw-r--r-- | keystoneclient/tests/v3/test_access.py | 9 |
3 files changed, 75 insertions, 17 deletions
diff --git a/keystoneclient/access.py b/keystoneclient/access.py index dfd7e9a..3c89cc1 100644 --- a/keystoneclient/access.py +++ b/keystoneclient/access.py @@ -33,35 +33,43 @@ class AccessInfo(dict): """ @classmethod - def factory(cls, resp=None, body=None, region_name=None, **kwargs): + def factory(cls, resp=None, body=None, region_name=None, auth_token=None, + **kwargs): """Create AccessInfo object given a successful auth response & body or a user-provided dict. """ # FIXME(jamielennox): Passing region_name is deprecated. Provide an # appropriate warning. + auth_ref = None if body is not None or len(kwargs): if AccessInfoV3.is_valid(body, **kwargs): - token = None - if resp: - token = resp.headers['X-Subject-Token'] + if resp and not auth_token: + auth_token = resp.headers['X-Subject-Token'] + # NOTE(jamielennox): these return AccessInfo because they + # already have auth_token installed on them. if body: if region_name: body['token']['region_name'] = region_name - return AccessInfoV3(token, **body['token']) + return AccessInfoV3(auth_token, **body['token']) else: - return AccessInfoV3(token, **kwargs) + return AccessInfoV3(auth_token, **kwargs) elif AccessInfoV2.is_valid(body, **kwargs): if body: if region_name: body['access']['region_name'] = region_name - return AccessInfoV2(**body['access']) + auth_ref = AccessInfoV2(**body['access']) else: - return AccessInfoV2(**kwargs) + auth_ref = AccessInfoV2(**kwargs) else: raise NotImplementedError('Unrecognized auth response') else: - return AccessInfoV2(**kwargs) + auth_ref = AccessInfoV2(**kwargs) + + if auth_token: + auth_ref.auth_token = auth_token + + return auth_ref def __init__(self, *args, **kwargs): super(AccessInfo, self).__init__(*args, **kwargs) @@ -110,7 +118,18 @@ class AccessInfo(dict): :returns: str """ - raise NotImplementedError() + return self['auth_token'] + + @auth_token.setter + def auth_token(self, value): + self['auth_token'] = value + + @auth_token.deleter + def auth_token(self): + try: + del self['auth_token'] + except KeyError: + pass @property def expires(self): @@ -395,9 +414,12 @@ class AccessInfoV2(AccessInfo): def has_service_catalog(self): return 'serviceCatalog' in self - @property + @AccessInfo.auth_token.getter def auth_token(self): - return self['token']['id'] + try: + return super(AccessInfoV2, self).auth_token + except KeyError: + return self['token']['id'] @property def expires(self): @@ -568,7 +590,7 @@ class AccessInfoV3(AccessInfo): token=token, region_name=self._region_name) if token: - self.update(auth_token=token) + self.auth_token = token @classmethod def is_valid(cls, body, **kwargs): @@ -583,10 +605,6 @@ class AccessInfoV3(AccessInfo): return 'catalog' in self @property - def auth_token(self): - return self['auth_token'] - - @property def expires(self): return timeutils.parse_isotime(self['expires_at']) diff --git a/keystoneclient/tests/v2_0/test_access.py b/keystoneclient/tests/v2_0/test_access.py index 52cb6b1..f384473 100644 --- a/keystoneclient/tests/v2_0/test_access.py +++ b/keystoneclient/tests/v2_0/test_access.py @@ -165,6 +165,37 @@ class AccessInfoTest(utils.TestCase, testresources.ResourcedTestCase): self.assertEqual(trust_id, token['access']['trust']['id']) + def test_override_auth_token(self): + token = fixture.V2Token() + token.set_scope() + token.add_role() + + new_auth_token = uuid.uuid4().hex + + auth_ref = access.AccessInfo.factory(body=token) + + self.assertEqual(token.token_id, auth_ref.auth_token) + + auth_ref.auth_token = new_auth_token + self.assertEqual(new_auth_token, auth_ref.auth_token) + + del auth_ref.auth_token + self.assertEqual(token.token_id, auth_ref.auth_token) + + def test_override_auth_token_in_factory(self): + token = fixture.V2Token() + token.set_scope() + token.add_role() + + new_auth_token = uuid.uuid4().hex + + auth_ref = access.AccessInfo.factory(body=token, + auth_token=new_auth_token) + + self.assertEqual(new_auth_token, auth_ref.auth_token) + del auth_ref.auth_token + self.assertEqual(token.token_id, auth_ref.auth_token) + def load_tests(loader, tests, pattern): return testresources.OptimisingTestSuite(tests) diff --git a/keystoneclient/tests/v3/test_access.py b/keystoneclient/tests/v3/test_access.py index 4353af7..024ac88 100644 --- a/keystoneclient/tests/v3/test_access.py +++ b/keystoneclient/tests/v3/test_access.py @@ -172,3 +172,12 @@ class AccessInfoTest(utils.TestCase): self.assertEqual(consumer_id, auth_ref['OS-OAUTH1']['consumer_id']) self.assertEqual(access_token_id, auth_ref['OS-OAUTH1']['access_token_id']) + + def test_override_auth_token(self): + token = fixture.V3Token() + token.set_project_scope() + + new_auth_token = uuid.uuid4().hex + auth_ref = access.AccessInfo.factory(body=token, + auth_token=new_auth_token) + self.assertEqual(new_auth_token, auth_ref.auth_token) |