summaryrefslogtreecommitdiff
path: root/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
diff options
context:
space:
mode:
Diffstat (limited to 'oauthlib/oauth2/rfc6749/grant_types/authorization_code.py')
-rw-r--r--oauthlib/oauth2/rfc6749/grant_types/authorization_code.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
index 850d70a..d56330a 100644
--- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
+++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
@@ -5,6 +5,8 @@ oauthlib.oauth2.rfc6749.grant_types
"""
from __future__ import absolute_import, unicode_literals
+import base64
+import hashlib
import json
import logging
@@ -17,6 +19,52 @@ from .base import GrantTypeBase
log = logging.getLogger(__name__)
+def code_challenge_method_s256(verifier, challenge):
+ """
+ If the "code_challenge_method" from `Section 4.3`_ was "S256", the
+ received "code_verifier" is hashed by SHA-256, base64url-encoded, and
+ then compared to the "code_challenge", i.e.:
+
+ BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) == code_challenge
+
+ How to implement a base64url-encoding
+ function without padding, based upon the standard base64-encoding
+ function that uses padding.
+
+ To be concrete, example C# code implementing these functions is shown
+ below. Similar code could be used in other languages.
+
+ static string base64urlencode(byte [] arg)
+ {
+ string s = Convert.ToBase64String(arg); // Regular base64 encoder
+ s = s.Split('=')[0]; // Remove any trailing '='s
+ s = s.Replace('+', '-'); // 62nd char of encoding
+ s = s.Replace('/', '_'); // 63rd char of encoding
+ return s;
+ }
+
+ In python urlsafe_b64encode is already replacing '+' and '/', but preserve
+ the trailing '='. So we have to remove it.
+
+ .. _`Section 4.3`: https://tools.ietf.org/html/rfc7636#section-4.3
+ """
+ return base64.urlsafe_b64encode(
+ hashlib.sha256(verifier.encode()).digest()
+ ).decode().rstrip('=') == challenge
+
+
+def code_challenge_method_plain(verifier, challenge):
+ """
+ If the "code_challenge_method" from `Section 4.3`_ was "plain", they are
+ compared directly, i.e.:
+
+ code_verifier == code_challenge.
+
+ .. _`Section 4.3`: https://tools.ietf.org/html/rfc7636#section-4.3
+ """
+ return verifier == challenge
+
+
class AuthorizationCodeGrant(GrantTypeBase):
"""`Authorization Code Grant`_
@@ -91,12 +139,28 @@ class AuthorizationCodeGrant(GrantTypeBase):
step (C). If valid, the authorization server responds back with
an access token and, optionally, a refresh token.
+ OAuth 2.0 public clients utilizing the Authorization Code Grant are
+ susceptible to the authorization code interception attack.
+
+ A technique to mitigate against the threat through the use of Proof Key for Code
+ Exchange (PKCE, pronounced "pixy") is implemented in the current oauthlib
+ implementation.
+
.. _`Authorization Code Grant`: https://tools.ietf.org/html/rfc6749#section-4.1
+ .. _`PKCE`: https://tools.ietf.org/html/rfc7636
"""
default_response_mode = 'query'
response_types = ['code']
+ # This dict below is private because as RFC mention it:
+ # "S256" is Mandatory To Implement (MTI) on the server.
+ #
+ _code_challenge_methods = {
+ 'plain': code_challenge_method_plain,
+ 'S256': code_challenge_method_s256
+ }
+
def create_authorization_code(self, request):
"""
Generates an authorization grant represented as a dictionary.
@@ -351,6 +415,20 @@ class AuthorizationCodeGrant(GrantTypeBase):
request.client_id, request.response_type)
raise errors.UnauthorizedClientError(request=request)
+ # OPTIONAL. Validate PKCE request or reply with "error"/"invalid_request"
+ # https://tools.ietf.org/html/rfc6749#section-4.4.1
+ if self.request_validator.is_pkce_required(request.client_id, request) is True:
+ if request.code_challenge is None:
+ raise errors.MissingCodeChallengeError(request=request)
+
+ if request.code_challenge is not None:
+ # OPTIONAL, defaults to "plain" if not present in the request.
+ if request.code_challenge_method is None:
+ request.code_challenge_method = "plain"
+
+ if request.code_challenge_method not in self._code_challenge_methods:
+ raise errors.UnsupportedCodeChallengeMethodError(request=request)
+
# OPTIONAL. The scope of the access request as described by Section 3.3
# https://tools.ietf.org/html/rfc6749#section-3.3
self.validate_scopes(request)
@@ -423,6 +501,33 @@ class AuthorizationCodeGrant(GrantTypeBase):
request.client_id, request.client, request.scopes)
raise errors.InvalidGrantError(request=request)
+ # OPTIONAL. Validate PKCE code_verifier
+ challenge = self.request_validator.get_code_challenge(request.code, request)
+
+ if challenge is not None:
+ if request.code_verifier is None:
+ raise errors.MissingCodeVerifierError(request=request)
+
+ challenge_method = self.request_validator.get_code_challenge_method(request.code, request)
+ if challenge_method is None:
+ raise errors.InvalidGrantError(request=request, description="Challenge method not found")
+
+ if challenge_method not in self._code_challenge_methods:
+ raise errors.ServerError(
+ description="code_challenge_method {} is not supported.".format(challenge_method),
+ request=request
+ )
+
+ if not self.validate_code_challenge(challenge,
+ challenge_method,
+ request.code_verifier):
+ log.debug('request provided a invalid code_verifier.')
+ raise errors.InvalidGrantError(request=request)
+ elif self.request_validator.is_pkce_required(request.client_id, request) is True:
+ if request.code_verifier is None:
+ raise errors.MissingCodeVerifierError(request=request)
+ raise errors.InvalidGrantError(request=request, description="Challenge not found")
+
for attr in ('user', 'scopes'):
if getattr(request, attr, None) is None:
log.debug('request.%s was not set on code validation.', attr)
@@ -450,3 +555,8 @@ class AuthorizationCodeGrant(GrantTypeBase):
for validator in self.custom_validators.post_token:
validator(request)
+
+ def validate_code_challenge(self, challenge, challenge_method, verifier):
+ if challenge_method in self._code_challenge_methods:
+ return self._code_challenge_methods[challenge_method](verifier, challenge)
+ raise NotImplementedError('Unknown challenge_method %s' % challenge_method)