From 47c229c5ae0803eae08233f60f846bd401f9543b Mon Sep 17 00:00:00 2001 From: Theron Luhn Date: Tue, 15 Feb 2022 16:33:41 -0800 Subject: Add CORS support for Refresh Token Grant. --- .../rfc6749/grant_types/authorization_code.py | 18 ---------- oauthlib/oauth2/rfc6749/grant_types/base.py | 18 ++++++++++ .../oauth2/rfc6749/grant_types/refresh_token.py | 1 + oauthlib/oauth2/rfc6749/request_validator.py | 1 + .../rfc6749/grant_types/test_refresh_token.py | 41 ++++++++++++++++++++++ 5 files changed, 61 insertions(+), 18 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py index b799823..858855a 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py +++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py @@ -10,7 +10,6 @@ import logging from oauthlib import common from .. import errors -from ..utils import is_secure_transport from .base import GrantTypeBase log = logging.getLogger(__name__) @@ -547,20 +546,3 @@ class AuthorizationCodeGrant(GrantTypeBase): if challenge_method in self._code_challenge_methods: return self._code_challenge_methods[challenge_method](verifier, challenge) raise NotImplementedError('Unknown challenge_method %s' % challenge_method) - - def _create_cors_headers(self, request): - """If CORS is allowed, create the appropriate headers.""" - if 'origin' not in request.headers: - return {} - - origin = request.headers['origin'] - if not is_secure_transport(origin): - log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin) - return {} - elif not self.request_validator.is_origin_allowed( - request.client_id, origin, request): - log.debug('Invalid origin "%s", CORS not allowed.', origin) - return {} - else: - log.debug('Valid origin "%s", injecting CORS headers.', origin) - return {'Access-Control-Allow-Origin': origin} diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py index a64f168..ca343a1 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/base.py +++ b/oauthlib/oauth2/rfc6749/grant_types/base.py @@ -10,6 +10,7 @@ from oauthlib.oauth2.rfc6749 import errors, utils from oauthlib.uri_validate import is_absolute_uri from ..request_validator import RequestValidator +from ..utils import is_secure_transport log = logging.getLogger(__name__) @@ -248,3 +249,20 @@ class GrantTypeBase: raise errors.MissingRedirectURIError(request=request) if not is_absolute_uri(request.redirect_uri): raise errors.InvalidRedirectURIError(request=request) + + def _create_cors_headers(self, request): + """If CORS is allowed, create the appropriate headers.""" + if 'origin' not in request.headers: + return {} + + origin = request.headers['origin'] + if not is_secure_transport(origin): + log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin) + return {} + elif not self.request_validator.is_origin_allowed( + request.client_id, origin, request): + log.debug('Invalid origin "%s", CORS not allowed.', origin) + return {} + else: + log.debug('Valid origin "%s", injecting CORS headers.', origin) + return {'Access-Control-Allow-Origin': origin} diff --git a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py index f801de4..ce33df0 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py +++ b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py @@ -69,6 +69,7 @@ class RefreshTokenGrant(GrantTypeBase): log.debug('Issuing new token to client id %r (%r), %r.', request.client_id, request.client, token) + headers.update(self._create_cors_headers(request)) return headers, json.dumps(token), 200 def validate_token_request(self, request): diff --git a/oauthlib/oauth2/rfc6749/request_validator.py b/oauthlib/oauth2/rfc6749/request_validator.py index 610a708..02a13fa 100644 --- a/oauthlib/oauth2/rfc6749/request_validator.py +++ b/oauthlib/oauth2/rfc6749/request_validator.py @@ -671,6 +671,7 @@ class RequestValidator: Method is used by: - Authorization Code Grant + - Refresh Token Grant """ return False diff --git a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py index 1d3e77a..581f2a4 100644 --- a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py +++ b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py @@ -18,6 +18,7 @@ class RefreshTokenGrantTest(TestCase): self.request = Request('http://a.b/path') self.request.grant_type = 'refresh_token' self.request.refresh_token = 'lsdkfhj230' + self.request.client_id = 'abcdef' self.request.client = mock_client self.request.scope = 'foo' self.mock_validator = mock.MagicMock() @@ -168,3 +169,43 @@ class RefreshTokenGrantTest(TestCase): del self.request.scope self.auth.validate_token_request(self.request) self.assertEqual(self.request.scopes, 'foo bar baz'.split()) + + # CORS + + def test_create_cors_headers(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = True + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertEqual( + headers['Access-Control-Allow-Origin'], 'https://foo.bar' + ) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) + + def test_create_cors_headers_no_origin(self): + bearer = BearerToken(self.mock_validator) + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_insecure_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'http://foo.bar' + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_invalid_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = False + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) -- cgit v1.2.1