From 55ce48b7218ee25794822141c5844eec4a0ff8d9 Mon Sep 17 00:00:00 2001 From: Theron Luhn Date: Tue, 16 Nov 2021 23:38:30 -0800 Subject: Add support for CORS in the token endpoint. --- .../rfc6749/grant_types/authorization_code.py | 19 ++++++++++ oauthlib/oauth2/rfc6749/request_validator.py | 25 +++++++++++++ .../rfc6749/grant_types/test_authorization_code.py | 41 ++++++++++++++++++++++ tests/oauth2/rfc6749/test_request_validator.py | 3 ++ 4 files changed, 88 insertions(+) diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py index 97aeca9..b799823 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py +++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py @@ -10,6 +10,7 @@ import logging from oauthlib import common from .. import errors +from ..utils import is_secure_transport from .base import GrantTypeBase log = logging.getLogger(__name__) @@ -312,6 +313,7 @@ class AuthorizationCodeGrant(GrantTypeBase): self.request_validator.save_token(token, request) self.request_validator.invalidate_authorization_code( request.client_id, request.code, request) + headers.update(self._create_cors_headers(request)) return headers, json.dumps(token), 200 def validate_authorization_request(self, request): @@ -545,3 +547,20 @@ class AuthorizationCodeGrant(GrantTypeBase): if challenge_method in self._code_challenge_methods: return self._code_challenge_methods[challenge_method](verifier, challenge) raise NotImplementedError('Unknown challenge_method %s' % challenge_method) + + def _create_cors_headers(self, request): + """If CORS is allowed, create the appropriate headers.""" + if 'origin' not in request.headers: + return {} + + origin = request.headers['origin'] + if not is_secure_transport(origin): + log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin) + return {} + elif not self.request_validator.is_origin_allowed( + request.client_id, origin, request): + log.debug('Invalid origin "%s", CORS not allowed.', origin) + return {} + else: + log.debug('Valid origin "%s", injecting CORS headers.', origin) + return {'Access-Control-Allow-Origin': origin} diff --git a/oauthlib/oauth2/rfc6749/request_validator.py b/oauthlib/oauth2/rfc6749/request_validator.py index 817d594..610a708 100644 --- a/oauthlib/oauth2/rfc6749/request_validator.py +++ b/oauthlib/oauth2/rfc6749/request_validator.py @@ -649,3 +649,28 @@ class RequestValidator: """ raise NotImplementedError('Subclasses must implement this method.') + + def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): + """Indicate if the given origin is allowed to access the token endpoint + via Cross-Origin Resource Sharing (CORS). CORS is used by browser-based + clients, such as Single-Page Applications, to perform the Authorization + Code Grant. + + (Note: If performing Authorization Code Grant via a public client such + as a browser, you should use PKCE as well.) + + If this method returns true, the appropriate CORS headers will be added + to the response. By default this method always returns False, meaning + CORS is disabled. + + :param client_id: Unicode client identifier. + :param redirect_uri: Unicode origin. + :param request: OAuthlib request. + :type request: oauthlib.common.Request + :rtype: bool + + Method is used by: + - Authorization Code Grant + + """ + return False diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py index dec5323..77e1a81 100644 --- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py +++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py @@ -28,6 +28,7 @@ class AuthorizationCodeGrantTest(TestCase): self.mock_validator = mock.MagicMock() self.mock_validator.is_pkce_required.return_value = False self.mock_validator.get_code_challenge.return_value = None + self.mock_validator.is_origin_allowed.return_value = False self.mock_validator.authenticate_client.side_effect = self.set_client self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator) @@ -339,3 +340,43 @@ class AuthorizationCodeGrantTest(TestCase): ) self.auth.create_authorization_response(self.request, bearer) self.mock_validator.save_token.assert_called_once() + + # CORS + + def test_create_cors_headers(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = True + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertEqual( + headers['Access-Control-Allow-Origin'], 'https://foo.bar' + ) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) + + def test_create_cors_headers_no_origin(self): + bearer = BearerToken(self.mock_validator) + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_insecure_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'http://foo.bar' + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_invalid_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = False + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) diff --git a/tests/oauth2/rfc6749/test_request_validator.py b/tests/oauth2/rfc6749/test_request_validator.py index 9688b5a..7a8d06b 100644 --- a/tests/oauth2/rfc6749/test_request_validator.py +++ b/tests/oauth2/rfc6749/test_request_validator.py @@ -46,3 +46,6 @@ class RequestValidatorTest(TestCase): self.assertRaises(NotImplementedError, v.validate_user, 'username', 'password', 'client', 'request') self.assertTrue(v.client_authentication_required('r')) + self.assertFalse( + v.is_origin_allowed('client_id', 'https://foo.bar', 'r') + ) -- cgit v1.2.1