summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTheron Luhn <theron@luhn.com>2021-11-16 23:38:30 -0800
committerAsif Saif Uddin <auvipy@gmail.com>2021-12-13 11:32:06 +0600
commit55ce48b7218ee25794822141c5844eec4a0ff8d9 (patch)
tree43f5be3223d894376094ea8710a865cd67b918be
parentea5ef62290ee306b20e4b57270d0a0575788a461 (diff)
downloadoauthlib-55ce48b7218ee25794822141c5844eec4a0ff8d9.tar.gz
Add support for CORS in the token endpoint.
-rw-r--r--oauthlib/oauth2/rfc6749/grant_types/authorization_code.py19
-rw-r--r--oauthlib/oauth2/rfc6749/request_validator.py25
-rw-r--r--tests/oauth2/rfc6749/grant_types/test_authorization_code.py41
-rw-r--r--tests/oauth2/rfc6749/test_request_validator.py3
4 files changed, 88 insertions, 0 deletions
diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
index 97aeca9..b799823 100644
--- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
+++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
@@ -10,6 +10,7 @@ import logging
from oauthlib import common
from .. import errors
+from ..utils import is_secure_transport
from .base import GrantTypeBase
log = logging.getLogger(__name__)
@@ -312,6 +313,7 @@ class AuthorizationCodeGrant(GrantTypeBase):
self.request_validator.save_token(token, request)
self.request_validator.invalidate_authorization_code(
request.client_id, request.code, request)
+ headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200
def validate_authorization_request(self, request):
@@ -545,3 +547,20 @@ class AuthorizationCodeGrant(GrantTypeBase):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)
+
+ def _create_cors_headers(self, request):
+ """If CORS is allowed, create the appropriate headers."""
+ if 'origin' not in request.headers:
+ return {}
+
+ origin = request.headers['origin']
+ if not is_secure_transport(origin):
+ log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
+ return {}
+ elif not self.request_validator.is_origin_allowed(
+ request.client_id, origin, request):
+ log.debug('Invalid origin "%s", CORS not allowed.', origin)
+ return {}
+ else:
+ log.debug('Valid origin "%s", injecting CORS headers.', origin)
+ return {'Access-Control-Allow-Origin': origin}
diff --git a/oauthlib/oauth2/rfc6749/request_validator.py b/oauthlib/oauth2/rfc6749/request_validator.py
index 817d594..610a708 100644
--- a/oauthlib/oauth2/rfc6749/request_validator.py
+++ b/oauthlib/oauth2/rfc6749/request_validator.py
@@ -649,3 +649,28 @@ class RequestValidator:
"""
raise NotImplementedError('Subclasses must implement this method.')
+
+ def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
+ """Indicate if the given origin is allowed to access the token endpoint
+ via Cross-Origin Resource Sharing (CORS). CORS is used by browser-based
+ clients, such as Single-Page Applications, to perform the Authorization
+ Code Grant.
+
+ (Note: If performing Authorization Code Grant via a public client such
+ as a browser, you should use PKCE as well.)
+
+ If this method returns true, the appropriate CORS headers will be added
+ to the response. By default this method always returns False, meaning
+ CORS is disabled.
+
+ :param client_id: Unicode client identifier.
+ :param redirect_uri: Unicode origin.
+ :param request: OAuthlib request.
+ :type request: oauthlib.common.Request
+ :rtype: bool
+
+ Method is used by:
+ - Authorization Code Grant
+
+ """
+ return False
diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
index dec5323..77e1a81 100644
--- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
+++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
@@ -28,6 +28,7 @@ class AuthorizationCodeGrantTest(TestCase):
self.mock_validator = mock.MagicMock()
self.mock_validator.is_pkce_required.return_value = False
self.mock_validator.get_code_challenge.return_value = None
+ self.mock_validator.is_origin_allowed.return_value = False
self.mock_validator.authenticate_client.side_effect = self.set_client
self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator)
@@ -339,3 +340,43 @@ class AuthorizationCodeGrantTest(TestCase):
)
self.auth.create_authorization_response(self.request, bearer)
self.mock_validator.save_token.assert_called_once()
+
+ # CORS
+
+ def test_create_cors_headers(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'https://foo.bar'
+ self.mock_validator.is_origin_allowed.return_value = True
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertEqual(
+ headers['Access-Control-Allow-Origin'], 'https://foo.bar'
+ )
+ self.mock_validator.is_origin_allowed.assert_called_once_with(
+ 'abcdef', 'https://foo.bar', self.request
+ )
+
+ def test_create_cors_headers_no_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_not_called()
+
+ def test_create_cors_headers_insecure_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'http://foo.bar'
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_not_called()
+
+ def test_create_cors_headers_invalid_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'https://foo.bar'
+ self.mock_validator.is_origin_allowed.return_value = False
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_called_once_with(
+ 'abcdef', 'https://foo.bar', self.request
+ )
diff --git a/tests/oauth2/rfc6749/test_request_validator.py b/tests/oauth2/rfc6749/test_request_validator.py
index 9688b5a..7a8d06b 100644
--- a/tests/oauth2/rfc6749/test_request_validator.py
+++ b/tests/oauth2/rfc6749/test_request_validator.py
@@ -46,3 +46,6 @@ class RequestValidatorTest(TestCase):
self.assertRaises(NotImplementedError, v.validate_user,
'username', 'password', 'client', 'request')
self.assertTrue(v.client_authentication_required('r'))
+ self.assertFalse(
+ v.is_origin_allowed('client_id', 'https://foo.bar', 'r')
+ )