diff options
author | Mark Adams <mark@markadams.me> | 2015-04-19 08:12:07 -0500 |
---|---|---|
committer | Mark Adams <mark@markadams.me> | 2015-04-19 08:19:01 -0500 |
commit | a3f2ec37f8660e3c445635ed417ae1100753309d (patch) | |
tree | f540ca536ff546ad95ed601d4c00bbb095f550a2 | |
parent | a97dc6ef58a729c8a90161cc53e8b2ec750cd405 (diff) | |
download | pyjwt-a3f2ec37f8660e3c445635ed417ae1100753309d.tar.gz |
Refactored JWS-specific logic out of PyJWT and into PyJWS superclass
-rw-r--r-- | jwt/__init__.py | 2 | ||||
-rw-r--r-- | jwt/api.py | 260 | ||||
-rw-r--r-- | jwt/api_jws.py | 183 | ||||
-rw-r--r-- | jwt/api_jwt.py | 136 | ||||
-rw-r--r-- | tests/test_api_jwt.py (renamed from tests/test_api.py) | 14 |
5 files changed, 327 insertions, 268 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 8cce5ef..86d95a3 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -16,7 +16,7 @@ __license__ = 'MIT' __copyright__ = 'Copyright 2015 José Padilla' -from .api import ( +from .api_jwt import ( encode, decode, register_algorithm, unregister_algorithm, PyJWT ) from .exceptions import ( diff --git a/jwt/api.py b/jwt/api.py deleted file mode 100644 index 273799d..0000000 --- a/jwt/api.py +++ /dev/null @@ -1,260 +0,0 @@ -import binascii -import json -import warnings - -from calendar import timegm -from collections import Mapping -from datetime import datetime, timedelta - -from .algorithms import Algorithm, get_default_algorithms # NOQA -from .compat import string_types, text_type, timedelta_total_seconds -from .exceptions import ( - DecodeError, ExpiredSignatureError, ImmatureSignatureError, - InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError, - InvalidIssuerError -) -from .utils import base64url_decode, base64url_encode, merge_dict - - -class PyJWT(object): - def __init__(self, algorithms=None, options=None): - self._algorithms = get_default_algorithms() - self._valid_algs = set(algorithms) if algorithms is not None else set(self._algorithms) - - # Remove algorithms that aren't on the whitelist - for key in list(self._algorithms.keys()): - if key not in self._valid_algs: - del self._algorithms[key] - - if not options: - options = {} - - default_options = { - 'verify_signature': True, - 'verify_exp': True, - 'verify_nbf': True, - 'verify_iat': True, - 'verify_aud': True, - } - - self.options = merge_dict(default_options, options) - - def register_algorithm(self, alg_id, alg_obj): - """ - Registers a new Algorithm for use when creating and verifying tokens. - """ - if alg_id in self._algorithms: - raise ValueError('Algorithm already has a handler.') - - if not isinstance(alg_obj, Algorithm): - raise TypeError('Object is not of type `Algorithm`') - - self._algorithms[alg_id] = alg_obj - self._valid_algs.add(alg_id) - - def unregister_algorithm(self, alg_id): - """ - Unregisters an Algorithm for use when creating and verifying tokens - Throws KeyError if algorithm is not registered. - """ - if alg_id not in self._algorithms: - raise KeyError('The specified algorithm could not be removed because it is not registered.') - - del self._algorithms[alg_id] - self._valid_algs.remove(alg_id) - - def get_algorithms(self): - """ - Returns a list of supported values for the 'alg' parameter. - """ - return list(self._valid_algs) - - def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=None): - segments = [] - - if algorithm is None: - algorithm = 'none' - - if algorithm not in self._valid_algs: - pass - - # Check that we get a mapping - if not isinstance(payload, Mapping): - raise TypeError('Expecting a mapping object, as json web token only' - 'support json objects.') - - # Header - header = {'typ': 'JWT', 'alg': algorithm} - - if headers: - header.update(headers) - - json_header = json.dumps( - header, - separators=(',', ':'), - cls=json_encoder - ).encode('utf-8') - - segments.append(base64url_encode(json_header)) - - # Payload - for time_claim in ['exp', 'iat', 'nbf']: - # Convert datetime to a intDate value in known time-format claims - if isinstance(payload.get(time_claim), datetime): - payload[time_claim] = timegm(payload[time_claim].utctimetuple()) - - json_payload = json.dumps( - payload, - separators=(',', ':'), - cls=json_encoder - ).encode('utf-8') - - segments.append(base64url_encode(json_payload)) - - # Segments - signing_input = b'.'.join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) - - except KeyError: - raise NotImplementedError('Algorithm not supported') - - segments.append(base64url_encode(signature)) - - return b'.'.join(segments) - - def decode(self, jwt, key='', verify=True, algorithms=None, options=None, **kwargs): - payload, signing_input, header, signature = self._load(jwt) - - if verify: - merged_options = merge_dict(self.options, options) - if merged_options.get('verify_signature'): - self._verify_signature(payload, signing_input, header, signature, - key, algorithms) - - self._validate_claims(payload, options=merged_options, **kwargs) - else: - warnings.warn("The verify parameter is deprecated. Please use options instead.", DeprecationWarning) - - return payload - - def _load(self, jwt): - if isinstance(jwt, text_type): - jwt = jwt.encode('utf-8') - try: - signing_input, crypto_segment = jwt.rsplit(b'.', 1) - header_segment, payload_segment = signing_input.split(b'.', 1) - except ValueError: - raise DecodeError('Not enough segments') - - try: - header_data = base64url_decode(header_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid header padding') - try: - header = json.loads(header_data.decode('utf-8')) - except ValueError as e: - raise DecodeError('Invalid header string: %s' % e) - if not isinstance(header, Mapping): - raise DecodeError('Invalid header string: must be a json object') - - try: - payload_data = base64url_decode(payload_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid payload padding') - try: - payload = json.loads(payload_data.decode('utf-8')) - except ValueError as e: - raise DecodeError('Invalid payload string: %s' % e) - if not isinstance(payload, Mapping): - raise DecodeError('Invalid payload string: must be a json object') - - try: - signature = base64url_decode(crypto_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid crypto padding') - - return (payload, signing_input, header, signature) - - def _verify_signature(self, payload, signing_input, header, signature, - key='', algorithms=None): - - alg = header['alg'] - - if algorithms is not None and alg not in algorithms: - raise InvalidAlgorithmError('The specified alg value is not allowed') - - try: - alg_obj = self._algorithms[alg] - key = alg_obj.prepare_key(key) - - if not alg_obj.verify(signing_input, key, signature): - raise DecodeError('Signature verification failed') - - except KeyError: - raise InvalidAlgorithmError('Algorithm not supported') - - def _validate_claims(self, payload, audience=None, issuer=None, leeway=0, - options=None, **kwargs): - if isinstance(leeway, timedelta): - leeway = timedelta_total_seconds(leeway) - - if not isinstance(audience, (string_types, type(None))): - raise TypeError('audience must be a string or None') - - now = timegm(datetime.utcnow().utctimetuple()) - - if 'iat' in payload and options.get('verify_iat'): - try: - iat = int(payload['iat']) - except ValueError: - raise DecodeError('Issued At claim (iat) must be an integer.') - - if iat > (now + leeway): - raise InvalidIssuedAtError('Issued At claim (iat) cannot be in the future.') - - if 'nbf' in payload and options.get('verify_nbf'): - try: - nbf = int(payload['nbf']) - except ValueError: - raise DecodeError('Not Before claim (nbf) must be an integer.') - - if nbf > (now + leeway): - raise ImmatureSignatureError('The token is not yet valid (nbf)') - - if 'exp' in payload and options.get('verify_exp'): - try: - exp = int(payload['exp']) - except ValueError: - raise DecodeError('Expiration Time claim (exp) must be an integer.') - - if exp < (now - leeway): - raise ExpiredSignatureError('Signature has expired') - - if 'aud' in payload and options.get('verify_aud'): - audience_claims = payload['aud'] - if isinstance(audience_claims, string_types): - audience_claims = [audience_claims] - if not isinstance(audience_claims, list): - raise InvalidAudienceError('Invalid claim format in token') - if any(not isinstance(c, string_types) for c in audience_claims): - raise InvalidAudienceError('Invalid claim format in token') - if audience not in audience_claims: - raise InvalidAudienceError('Invalid audience') - elif audience is not None: - # Application specified an audience, but it could not be - # verified since the token does not contain a claim. - raise InvalidAudienceError('No audience claim in token') - - if issuer is not None: - if payload.get('iss') != issuer: - raise InvalidIssuerError('Invalid issuer') - - -_jwt_global_obj = PyJWT() -encode = _jwt_global_obj.encode -decode = _jwt_global_obj.decode -register_algorithm = _jwt_global_obj.register_algorithm -unregister_algorithm = _jwt_global_obj.unregister_algorithm diff --git a/jwt/api_jws.py b/jwt/api_jws.py new file mode 100644 index 0000000..c26044a --- /dev/null +++ b/jwt/api_jws.py @@ -0,0 +1,183 @@ +import binascii +import json +import warnings + +from collections import Mapping + +from .algorithms import Algorithm, get_default_algorithms # NOQA +from .compat import text_type +from .exceptions import DecodeError, InvalidAlgorithmError +from .utils import base64url_decode, base64url_encode, merge_dict + + +class PyJWS(object): + header_typ = 'JWT' + + def __init__(self, algorithms=None, options=None): + self._algorithms = get_default_algorithms() + self._valid_algs = (set(algorithms) if algorithms is not None + else set(self._algorithms)) + + # Remove algorithms that aren't on the whitelist + for key in list(self._algorithms.keys()): + if key not in self._valid_algs: + del self._algorithms[key] + + if not options: + options = {} + + self.options = merge_dict(self._get_default_options(), options) + + @staticmethod + def _get_default_options(): + return { + 'verify_signature': True + } + + def register_algorithm(self, alg_id, alg_obj): + """ + Registers a new Algorithm for use when creating and verifying tokens. + """ + if alg_id in self._algorithms: + raise ValueError('Algorithm already has a handler.') + + if not isinstance(alg_obj, Algorithm): + raise TypeError('Object is not of type `Algorithm`') + + self._algorithms[alg_id] = alg_obj + self._valid_algs.add(alg_id) + + def unregister_algorithm(self, alg_id): + """ + Unregisters an Algorithm for use when creating and verifying tokens + Throws KeyError if algorithm is not registered. + """ + if alg_id not in self._algorithms: + raise KeyError('The specified algorithm could not be removed' + ' because it is not registered.') + + del self._algorithms[alg_id] + self._valid_algs.remove(alg_id) + + def get_algorithms(self): + """ + Returns a list of supported values for the 'alg' parameter. + """ + return list(self._valid_algs) + + def encode(self, payload, key, algorithm='HS256', headers=None, + json_encoder=None): + segments = [] + + if algorithm is None: + algorithm = 'none' + + if algorithm not in self._valid_algs: + pass + + # Header + header = {'typ': self.header_typ, 'alg': algorithm} + + if headers: + header.update(headers) + + json_header = json.dumps( + header, + separators=(',', ':'), + cls=json_encoder + ).encode('utf-8') + + segments.append(base64url_encode(json_header)) + + json_payload = payload.encode('utf-8') + + segments.append(base64url_encode(json_payload)) + + # Segments + signing_input = b'.'.join(segments) + try: + alg_obj = self._algorithms[algorithm] + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) + + except KeyError: + raise NotImplementedError('Algorithm not supported') + + segments.append(base64url_encode(signature)) + + return b'.'.join(segments) + + def decode(self, jws, key='', verify=True, algorithms=None, options=None, + **kwargs): + payload, signing_input, header, signature = self._load(jws) + + if verify: + merged_options = merge_dict(self.options, options) + if merged_options.get('verify_signature'): + self._verify_signature(payload, signing_input, header, signature, + key, algorithms) + else: + warnings.warn('The verify parameter is deprecated. ' + 'Please use options instead.', DeprecationWarning) + + return payload + + def _load(self, jwt): + if isinstance(jwt, text_type): + jwt = jwt.encode('utf-8') + + try: + signing_input, crypto_segment = jwt.rsplit(b'.', 1) + header_segment, payload_segment = signing_input.split(b'.', 1) + except ValueError: + raise DecodeError('Not enough segments') + + try: + header_data = base64url_decode(header_segment) + except (TypeError, binascii.Error): + raise DecodeError('Invalid header padding') + + try: + header = json.loads(header_data.decode('utf-8')) + except ValueError as e: + raise DecodeError('Invalid header string: %s' % e) + + if not isinstance(header, Mapping): + raise DecodeError('Invalid header string: must be a json object') + + try: + payload = base64url_decode(payload_segment) + except (TypeError, binascii.Error): + raise DecodeError('Invalid payload padding') + + try: + signature = base64url_decode(crypto_segment) + except (TypeError, binascii.Error): + raise DecodeError('Invalid crypto padding') + + return (payload, signing_input, header, signature) + + def _verify_signature(self, payload, signing_input, header, signature, + key='', algorithms=None): + + alg = header['alg'] + + if algorithms is not None and alg not in algorithms: + raise InvalidAlgorithmError('The specified alg value is not allowed') + + try: + alg_obj = self._algorithms[alg] + key = alg_obj.prepare_key(key) + + if not alg_obj.verify(signing_input, key, signature): + raise DecodeError('Signature verification failed') + + except KeyError: + raise InvalidAlgorithmError('Algorithm not supported') + + +_jws_global_obj = PyJWS() +encode = _jws_global_obj.encode +decode = _jws_global_obj.decode +register_algorithm = _jws_global_obj.register_algorithm +unregister_algorithm = _jws_global_obj.unregister_algorithm diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py new file mode 100644 index 0000000..5cef9db --- /dev/null +++ b/jwt/api_jwt.py @@ -0,0 +1,136 @@ +import json + +from calendar import timegm +from collections import Mapping +from datetime import datetime, timedelta + +from .api_jws import PyJWS +from .algorithms import Algorithm, get_default_algorithms # NOQA +from .compat import string_types, timedelta_total_seconds +from .exceptions import ( + DecodeError, ExpiredSignatureError, ImmatureSignatureError, + InvalidAudienceError, InvalidIssuedAtError, + InvalidIssuerError +) +from .utils import merge_dict + + +class PyJWT(PyJWS): + header_type = 'JWT' + + @staticmethod + def _get_default_options(): + return { + 'verify_signature': True, + 'verify_exp': True, + 'verify_nbf': True, + 'verify_iat': True, + 'verify_aud': True, + } + + def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=None): + # Check that we get a mapping + if not isinstance(payload, Mapping): + raise TypeError('Expecting a mapping object, as json web token only' + 'support json objects.') + + # Payload + for time_claim in ['exp', 'iat', 'nbf']: + # Convert datetime to a intDate value in known time-format claims + if isinstance(payload.get(time_claim), datetime): + payload[time_claim] = timegm(payload[time_claim].utctimetuple()) + + json_payload = json.dumps( + payload, + separators=(',', ':'), + cls=json_encoder + ).encode('utf-8') + + return super(PyJWT, self).encode( + json_payload, key, algorithm, headers, json_encoder + ) + + def decode(self, jwt, key='', verify=True, algorithms=None, options=None, + **kwargs): + payload, signing_input, header, signature = self._load(jwt) + + decoded = super(PyJWT, self).decode(jwt, key, verify, algorithms, + options, **kwargs) + + try: + payload = json.loads(decoded.decode('utf-8')) + except ValueError as e: + raise DecodeError('Invalid payload string: %s' % e) + if not isinstance(payload, Mapping): + raise DecodeError('Invalid payload string: must be a json object') + + if verify: + merged_options = merge_dict(self.options, options) + self._validate_claims(payload, options=merged_options, **kwargs) + + return payload + + def _validate_claims(self, payload, audience=None, issuer=None, leeway=0, + options=None, **kwargs): + if isinstance(leeway, timedelta): + leeway = timedelta_total_seconds(leeway) + + if not isinstance(audience, (string_types, type(None))): + raise TypeError('audience must be a string or None') + + now = timegm(datetime.utcnow().utctimetuple()) + + if 'iat' in payload and options.get('verify_iat'): + try: + iat = int(payload['iat']) + except ValueError: + raise DecodeError('Issued At claim (iat) must be an integer.') + + if iat > (now + leeway): + raise InvalidIssuedAtError('Issued At claim (iat) cannot be in' + ' the future.') + + if 'nbf' in payload and options.get('verify_nbf'): + try: + nbf = int(payload['nbf']) + except ValueError: + raise DecodeError('Not Before claim (nbf) must be an integer.') + + if nbf > (now + leeway): + raise ImmatureSignatureError('The token is not yet valid (nbf)') + + if 'exp' in payload and options.get('verify_exp'): + try: + exp = int(payload['exp']) + except ValueError: + raise DecodeError('Expiration Time claim (exp) must be an' + ' integer.') + + if exp < (now - leeway): + raise ExpiredSignatureError('Signature has expired') + + if 'aud' in payload and options.get('verify_aud'): + audience_claims = payload['aud'] + if isinstance(audience_claims, string_types): + audience_claims = [audience_claims] + if not isinstance(audience_claims, list): + raise InvalidAudienceError('Invalid claim format in token') + if any(not isinstance(c, string_types) for c in audience_claims): + raise InvalidAudienceError('Invalid claim format in token') + if audience not in audience_claims: + raise InvalidAudienceError('Invalid audience') + elif audience is not None: + # Application specified an audience, but it could not be + # verified since the token does not contain a claim. + raise InvalidAudienceError('No audience claim in token') + + if issuer is not None: + if payload.get('iss') != issuer: + raise InvalidIssuerError('Invalid issuer') + + +_jwt_global_obj = PyJWT() +encode = _jwt_global_obj.encode +decode = _jwt_global_obj.decode +register_algorithm = _jwt_global_obj.register_algorithm +unregister_algorithm = _jwt_global_obj.unregister_algorithm diff --git a/tests/test_api.py b/tests/test_api_jwt.py index 4a4399c..7911a6d 100644 --- a/tests/test_api.py +++ b/tests/test_api_jwt.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from decimal import Decimal from jwt.algorithms import Algorithm -from jwt.api import PyJWT +from jwt.api_jwt import PyJWT from jwt.exceptions import ( DecodeError, ExpiredSignatureError, ImmatureSignatureError, InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError, @@ -146,9 +146,9 @@ class TestAPI: def test_decode_with_non_mapping_payload_throws_exception(self, jwt): secret = 'secret' - example_jwt = ('eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9' - '.MQ' # == 1 - '.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8') + example_jwt = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.' + 'MQ.' # == 1 + 'AbcSR3DWum91KOgfKxUHm78rLs_DrrZ1CrDgpUFFzls') with pytest.raises(DecodeError) as context: jwt.decode(example_jwt, secret) @@ -466,9 +466,9 @@ class TestAPI: def test_decode_invalid_payload_string(self, jwt): example_jwt = ( - 'eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9' - '.eyJoZWxsb-kiOiAid29ybGQifQ==' - '.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8') + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aGVsb' + 'G8gd29ybGQ.SIr03zM64awWRdPrAM_61QWsZchAtgDV' + '3pphfHPPWkI') example_secret = 'secret' with pytest.raises(DecodeError) as exc: |