summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTheron Luhn <theron@luhn.com>2022-02-15 16:33:41 -0800
committerTheron Luhn <theron@luhn.com>2022-02-15 16:33:41 -0800
commit47c229c5ae0803eae08233f60f846bd401f9543b (patch)
tree9180fb7edd8e65fb4d66b9a4120fceab720ea6b2
parent6b1f5db98d464c31db807b7ab0e0fe43ebca46d0 (diff)
downloadoauthlib-47c229c5ae0803eae08233f60f846bd401f9543b.tar.gz
Add CORS support for Refresh Token Grant.
-rw-r--r--oauthlib/oauth2/rfc6749/grant_types/authorization_code.py18
-rw-r--r--oauthlib/oauth2/rfc6749/grant_types/base.py18
-rw-r--r--oauthlib/oauth2/rfc6749/grant_types/refresh_token.py1
-rw-r--r--oauthlib/oauth2/rfc6749/request_validator.py1
-rw-r--r--tests/oauth2/rfc6749/grant_types/test_refresh_token.py41
5 files changed, 61 insertions, 18 deletions
diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
index b799823..858855a 100644
--- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
+++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
@@ -10,7 +10,6 @@ import logging
from oauthlib import common
from .. import errors
-from ..utils import is_secure_transport
from .base import GrantTypeBase
log = logging.getLogger(__name__)
@@ -547,20 +546,3 @@ class AuthorizationCodeGrant(GrantTypeBase):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)
-
- def _create_cors_headers(self, request):
- """If CORS is allowed, create the appropriate headers."""
- if 'origin' not in request.headers:
- return {}
-
- origin = request.headers['origin']
- if not is_secure_transport(origin):
- log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
- return {}
- elif not self.request_validator.is_origin_allowed(
- request.client_id, origin, request):
- log.debug('Invalid origin "%s", CORS not allowed.', origin)
- return {}
- else:
- log.debug('Valid origin "%s", injecting CORS headers.', origin)
- return {'Access-Control-Allow-Origin': origin}
diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py
index a64f168..ca343a1 100644
--- a/oauthlib/oauth2/rfc6749/grant_types/base.py
+++ b/oauthlib/oauth2/rfc6749/grant_types/base.py
@@ -10,6 +10,7 @@ from oauthlib.oauth2.rfc6749 import errors, utils
from oauthlib.uri_validate import is_absolute_uri
from ..request_validator import RequestValidator
+from ..utils import is_secure_transport
log = logging.getLogger(__name__)
@@ -248,3 +249,20 @@ class GrantTypeBase:
raise errors.MissingRedirectURIError(request=request)
if not is_absolute_uri(request.redirect_uri):
raise errors.InvalidRedirectURIError(request=request)
+
+ def _create_cors_headers(self, request):
+ """If CORS is allowed, create the appropriate headers."""
+ if 'origin' not in request.headers:
+ return {}
+
+ origin = request.headers['origin']
+ if not is_secure_transport(origin):
+ log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
+ return {}
+ elif not self.request_validator.is_origin_allowed(
+ request.client_id, origin, request):
+ log.debug('Invalid origin "%s", CORS not allowed.', origin)
+ return {}
+ else:
+ log.debug('Valid origin "%s", injecting CORS headers.', origin)
+ return {'Access-Control-Allow-Origin': origin}
diff --git a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
index f801de4..ce33df0 100644
--- a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
+++ b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
@@ -69,6 +69,7 @@ class RefreshTokenGrant(GrantTypeBase):
log.debug('Issuing new token to client id %r (%r), %r.',
request.client_id, request.client, token)
+ headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200
def validate_token_request(self, request):
diff --git a/oauthlib/oauth2/rfc6749/request_validator.py b/oauthlib/oauth2/rfc6749/request_validator.py
index 610a708..02a13fa 100644
--- a/oauthlib/oauth2/rfc6749/request_validator.py
+++ b/oauthlib/oauth2/rfc6749/request_validator.py
@@ -671,6 +671,7 @@ class RequestValidator:
Method is used by:
- Authorization Code Grant
+ - Refresh Token Grant
"""
return False
diff --git a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py
index 1d3e77a..581f2a4 100644
--- a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py
+++ b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py
@@ -18,6 +18,7 @@ class RefreshTokenGrantTest(TestCase):
self.request = Request('http://a.b/path')
self.request.grant_type = 'refresh_token'
self.request.refresh_token = 'lsdkfhj230'
+ self.request.client_id = 'abcdef'
self.request.client = mock_client
self.request.scope = 'foo'
self.mock_validator = mock.MagicMock()
@@ -168,3 +169,43 @@ class RefreshTokenGrantTest(TestCase):
del self.request.scope
self.auth.validate_token_request(self.request)
self.assertEqual(self.request.scopes, 'foo bar baz'.split())
+
+ # CORS
+
+ def test_create_cors_headers(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'https://foo.bar'
+ self.mock_validator.is_origin_allowed.return_value = True
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertEqual(
+ headers['Access-Control-Allow-Origin'], 'https://foo.bar'
+ )
+ self.mock_validator.is_origin_allowed.assert_called_once_with(
+ 'abcdef', 'https://foo.bar', self.request
+ )
+
+ def test_create_cors_headers_no_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_not_called()
+
+ def test_create_cors_headers_insecure_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'http://foo.bar'
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_not_called()
+
+ def test_create_cors_headers_invalid_origin(self):
+ bearer = BearerToken(self.mock_validator)
+ self.request.headers['origin'] = 'https://foo.bar'
+ self.mock_validator.is_origin_allowed.return_value = False
+
+ headers = self.auth.create_token_response(self.request, bearer)[0]
+ self.assertNotIn('Access-Control-Allow-Origin', headers)
+ self.mock_validator.is_origin_allowed.assert_called_once_with(
+ 'abcdef', 'https://foo.bar', self.request
+ )