From f6b625886d03f1582a7a99317e84c57d03895339 Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Wed, 2 Jun 2021 11:12:32 +0300 Subject: Move refresh_id_token to validator function --- oauthlib/openid/connect/core/grant_types/refresh_token.py | 6 ++---- oauthlib/openid/connect/core/request_validator.py | 12 ++++++++++++ tests/openid/connect/core/grant_types/test_refresh_token.py | 8 +++++++- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/oauthlib/openid/connect/core/grant_types/refresh_token.py b/oauthlib/openid/connect/core/grant_types/refresh_token.py index 386b57c..43e4499 100644 --- a/oauthlib/openid/connect/core/grant_types/refresh_token.py +++ b/oauthlib/openid/connect/core/grant_types/refresh_token.py @@ -15,8 +15,7 @@ log = logging.getLogger(__name__) class RefreshTokenGrant(GrantTypeBase): - def __init__(self, refresh_id_token=True, request_validator=None, **kwargs): - self.refresh_id_token = refresh_id_token + def __init__(self, request_validator=None, **kwargs): self.proxy_target = OAuth2RefreshTokenGrant( request_validator=request_validator, **kwargs) self.register_token_modifier(self.add_id_token) @@ -29,8 +28,7 @@ class RefreshTokenGrant(GrantTypeBase): The authorization_code version of this method is used to retrieve the nonce accordingly to the code storage. """ - # Treat it as normal OAuth 2 auth code request if openid is not present - if not self.refresh_id_token: + if not self.request_validator.refresh_id_token(request): return token return super().add_id_token(token, token_handler, request) diff --git a/oauthlib/openid/connect/core/request_validator.py b/oauthlib/openid/connect/core/request_validator.py index e8f334b..47c4cd9 100644 --- a/oauthlib/openid/connect/core/request_validator.py +++ b/oauthlib/openid/connect/core/request_validator.py @@ -306,3 +306,15 @@ class RequestValidator(OAuth2RequestValidator): Method is used by: UserInfoEndpoint """ + + def refresh_id_token(self, request): + """Whether the id token should be refreshed. Default, True + + :param request: OAuthlib request. + :type request: oauthlib.common.Request + :rtype: True or False + + Method is used by: + RefreshTokenGrant + """ + return True diff --git a/tests/openid/connect/core/grant_types/test_refresh_token.py b/tests/openid/connect/core/grant_types/test_refresh_token.py index c19de18..8126e1b 100644 --- a/tests/openid/connect/core/grant_types/test_refresh_token.py +++ b/tests/openid/connect/core/grant_types/test_refresh_token.py @@ -60,9 +60,12 @@ class OpenIDRefreshTokenTest(TestCase): self.assertIn('token_type', token) self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'hello openid') + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) def test_refresh_id_token_false(self): - self.auth.refresh_id_token = False + self.mock_validator.refresh_id_token.return_value = False self.mock_validator.get_original_scopes.return_value = [ 'hello', 'openid' ] @@ -80,6 +83,9 @@ class OpenIDRefreshTokenTest(TestCase): self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'hello openid') self.assertNotIn('id_token', token) + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) def test_refresh_token_without_openid_scope(self): self.request.scope = "hello" -- cgit v1.2.1