diff options
author | José Padilla <jpadilla@webapplicate.com> | 2015-01-18 13:28:18 -0400 |
---|---|---|
committer | José Padilla <jpadilla@webapplicate.com> | 2015-01-18 13:28:18 -0400 |
commit | 6290a1c8b5928d691f9fc91b733550b93b5f8cc7 (patch) | |
tree | ba7a79fe402317c00ace2eadeb5901ea463cb809 | |
parent | 0afba10cf16834e154a59280de089c30de3d9a61 (diff) | |
parent | eff1505b5df81540eb6c3c73b1d8c924fd1c2ca5 (diff) | |
download | pyjwt-6290a1c8b5928d691f9fc91b733550b93b5f8cc7.tar.gz |
Merge pull request #71 from mark-adams/algorithm-registry
Fixes #70. Implemented registry pattern for class-based algorithms
-rw-r--r-- | jwt/__init__.py | 225 | ||||
-rw-r--r-- | jwt/algorithms.py | 200 | ||||
-rw-r--r-- | jwt/utils.py | 14 | ||||
-rwxr-xr-x | setup.py | 3 | ||||
-rw-r--r-- | tests/test_jwt.py | 72 |
5 files changed, 265 insertions, 249 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 3a70913..b9a9986 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -5,16 +5,15 @@ Minimum implementation based on this spec: http://self-issued.info/docs/draft-jones-json-web-token-01.html """ -import base64 import binascii -import hashlib -import hmac -from datetime import datetime, timedelta + from calendar import timegm from collections import Mapping +from datetime import datetime, timedelta -from .compat import (json, string_types, text_type, constant_time_compare, - timedelta_total_seconds) +from jwt.utils import base64url_decode, base64url_encode + +from .compat import (json, string_types, text_type, timedelta_total_seconds) __version__ = '0.4.1' @@ -22,6 +21,7 @@ __all__ = [ # Functions 'encode', 'decode', + 'register_algorithm', # Exceptions 'InvalidTokenError', @@ -33,9 +33,25 @@ __all__ = [ # Deprecated aliases 'ExpiredSignature', 'InvalidAudience', - 'InvalidIssuer', + 'InvalidIssuer' ] +_algorithms = {} + + +def register_algorithm(alg_id, alg_obj): + """ Registers a new Algorithm for use when creating and verifying JWTs """ + if alg_id in _algorithms: + raise ValueError('Algorithm already has a handler.') + + if not isinstance(alg_obj, Algorithm): + raise TypeError('Object is not of type `Algorithm`') + + _algorithms[alg_id] = alg_obj + +from jwt.algorithms import Algorithm, _register_default_algorithms # NOQA +_register_default_algorithms() + class InvalidTokenError(Exception): pass @@ -56,187 +72,11 @@ class InvalidAudienceError(InvalidTokenError): class InvalidIssuerError(InvalidTokenError): pass - # Compatibility aliases (deprecated) ExpiredSignature = ExpiredSignatureError InvalidAudience = InvalidAudienceError InvalidIssuer = InvalidIssuerError -signing_methods = { - 'none': lambda msg, key: b'', - 'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(), - 'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(), - 'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest() -} - -verify_methods = { - 'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(), - 'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(), - 'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest() -} - - -def prepare_HS_key(key): - if not isinstance(key, string_types) and not isinstance(key, bytes): - raise TypeError('Expecting a string- or bytes-formatted key.') - - if isinstance(key, text_type): - key = key.encode('utf-8') - - return key - -prepare_key_methods = { - 'none': lambda key: None, - 'HS256': prepare_HS_key, - 'HS384': prepare_HS_key, - 'HS512': prepare_HS_key -} - -try: - from cryptography.hazmat.primitives import interfaces, hashes - from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key, load_ssh_public_key - ) - from cryptography.hazmat.primitives.asymmetric import ec, padding - from cryptography.hazmat.backends import default_backend - from cryptography.exceptions import InvalidSignature - - def sign_rsa(msg, key, hashalg): - signer = key.signer( - padding.PKCS1v15(), - hashalg - ) - - signer.update(msg) - return signer.finalize() - - def verify_rsa(msg, key, hashalg, sig): - verifier = key.verifier( - sig, - padding.PKCS1v15(), - hashalg - ) - - verifier.update(msg) - - try: - verifier.verify() - return True - except InvalidSignature: - return False - - signing_methods.update({ - 'RS256': lambda msg, key: sign_rsa(msg, key, hashes.SHA256()), - 'RS384': lambda msg, key: sign_rsa(msg, key, hashes.SHA384()), - 'RS512': lambda msg, key: sign_rsa(msg, key, hashes.SHA512()) - }) - - verify_methods.update({ - 'RS256': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA256(), sig), - 'RS384': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA384(), sig), - 'RS512': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA512(), sig) - }) - - def prepare_RS_key(key): - if isinstance(key, interfaces.RSAPrivateKey) or \ - isinstance(key, interfaces.RSAPublicKey): - return key - - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') - - try: - if key.startswith(b'ssh-rsa'): - key = load_ssh_public_key(key, backend=default_backend()) - else: - key = load_pem_private_key(key, password=None, backend=default_backend()) - except ValueError: - key = load_pem_public_key(key, backend=default_backend()) - else: - raise TypeError('Expecting a PEM-formatted key.') - - return key - - prepare_key_methods.update({ - 'RS256': prepare_RS_key, - 'RS384': prepare_RS_key, - 'RS512': prepare_RS_key - }) - - def sign_ecdsa(msg, key, hashalg): - signer = key.signer(ec.ECDSA(hashalg)) - - signer.update(msg) - return signer.finalize() - - def verify_ecdsa(msg, key, hashalg, sig): - verifier = key.verifier(sig, ec.ECDSA(hashalg)) - - verifier.update(msg) - - try: - verifier.verify() - return True - except InvalidSignature: - return False - - signing_methods.update({ - 'ES256': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA256()), - 'ES384': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA384()), - 'ES512': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA512()), - }) - - verify_methods.update({ - 'ES256': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA256(), sig), - 'ES384': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA384(), sig), - 'ES512': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA512(), sig), - }) - - def prepare_ES_key(key): - if isinstance(key, interfaces.EllipticCurvePrivateKey) or \ - isinstance(key, interfaces.EllipticCurvePublicKey): - return key - - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') - - # Attempt to load key. We don't know if it's - # a Signing Key or a Verifying Key, so we try - # the Verifying Key first. - try: - key = load_pem_public_key(key, backend=default_backend()) - except ValueError: - key = load_pem_private_key(key, password=None, backend=default_backend()) - - else: - raise TypeError('Expecting a PEM-formatted key.') - - return key - - prepare_key_methods.update({ - 'ES256': prepare_ES_key, - 'ES384': prepare_ES_key, - 'ES512': prepare_ES_key - }) - -except ImportError: - pass - - -def base64url_decode(input): - rem = len(input) % 4 - - if rem > 0: - input += b'=' * (4 - rem) - - return base64.urlsafe_b64decode(input) - - -def base64url_encode(input): - return base64.urlsafe_b64encode(input).replace(b'=', b'') - def header(jwt): if isinstance(jwt, text_type): @@ -290,8 +130,10 @@ def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None): # Segments signing_input = b'.'.join(segments) try: - key = prepare_key_methods[algorithm](key) - signature = signing_methods[algorithm](signing_input, key) + alg_obj = _algorithms[algorithm] + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) + except KeyError: raise NotImplementedError('Algorithm not supported') @@ -360,17 +202,12 @@ def verify_signature(payload, signing_input, header, signature, key='', raise TypeError('audience must be a string or None') try: - algorithm = header['alg'].upper() - key = prepare_key_methods[algorithm](key) + alg_obj = _algorithms[header['alg'].upper()] + key = alg_obj.prepare_key(key) - if algorithm.startswith('HS'): - expected = verify_methods[algorithm](signing_input, key) + if not alg_obj.verify(signing_input, key, signature): + raise DecodeError('Signature verification failed') - if not constant_time_compare(signature, expected): - raise DecodeError('Signature verification failed') - else: - if not verify_methods[algorithm](signing_input, key, signature): - raise DecodeError('Signature verification failed') except KeyError: raise DecodeError('Algorithm not supported') diff --git a/jwt/algorithms.py b/jwt/algorithms.py new file mode 100644 index 0000000..89ea75b --- /dev/null +++ b/jwt/algorithms.py @@ -0,0 +1,200 @@ +import hashlib +import hmac + +from jwt import register_algorithm +from jwt.compat import constant_time_compare, string_types, text_type + +try: + from cryptography.hazmat.primitives import interfaces, hashes + from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, load_pem_public_key, load_ssh_public_key + ) + from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.backends import default_backend + from cryptography.exceptions import InvalidSignature + + has_crypto = True +except ImportError: + has_crypto = False + + +def _register_default_algorithms(): + """ Registers the algorithms that are implemented by the library """ + register_algorithm('none', NoneAlgorithm()) + register_algorithm('HS256', HMACAlgorithm(hashlib.sha256)) + register_algorithm('HS384', HMACAlgorithm(hashlib.sha384)) + register_algorithm('HS512', HMACAlgorithm(hashlib.sha512)) + + if has_crypto: + register_algorithm('RS256', RSAAlgorithm(hashes.SHA256())) + register_algorithm('RS384', RSAAlgorithm(hashes.SHA384())) + register_algorithm('RS512', RSAAlgorithm(hashes.SHA512())) + + register_algorithm('ES256', ECAlgorithm(hashes.SHA256())) + register_algorithm('ES384', ECAlgorithm(hashes.SHA384())) + register_algorithm('ES512', ECAlgorithm(hashes.SHA512())) + + +class Algorithm(object): + """ The interface for an algorithm used to sign and verify JWTs """ + def prepare_key(self, key): + """ + Performs necessary validation and conversions on the key and returns + the key value in the proper format for sign() and verify() + """ + raise NotImplementedError + + def sign(self, msg, key): + """ + Returns a digital signature for the specified message using the + specified key value + """ + raise NotImplementedError + + def verify(self, msg, key, sig): + """ + Verifies that the specified digital signature is valid for the specified + message and key values. + """ + raise NotImplementedError + + +class NoneAlgorithm(Algorithm): + """ + Placeholder for use when no signing or verification operations are required + """ + def prepare_key(self, key): + return None + + def sign(self, msg, key): + return b'' + + def verify(self, msg, key): + return True + + +class HMACAlgorithm(Algorithm): + """ + Performs signing and verification operations using HMAC and the specified + hash function + """ + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if not isinstance(key, string_types) and not isinstance(key, bytes): + raise TypeError('Expecting a string- or bytes-formatted key.') + + if isinstance(key, text_type): + key = key.encode('utf-8') + + return key + + def sign(self, msg, key): + return hmac.new(key, msg, self.hash_alg).digest() + + def verify(self, msg, key, sig): + return constant_time_compare(sig, self.sign(msg, key)) + +if has_crypto: + + class RSAAlgorithm(Algorithm): + """ + Performs signing and verification operations using RSASSA-PKCS-v1_5 and + the specified hash function + """ + + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if isinstance(key, interfaces.RSAPrivateKey) or \ + isinstance(key, interfaces.RSAPublicKey): + return key + + if isinstance(key, string_types): + if isinstance(key, text_type): + key = key.encode('utf-8') + + try: + if key.startswith(b'ssh-rsa'): + key = load_ssh_public_key(key, backend=default_backend()) + else: + key = load_pem_private_key(key, password=None, backend=default_backend()) + except ValueError: + key = load_pem_public_key(key, backend=default_backend()) + else: + raise TypeError('Expecting a PEM-formatted key.') + + return key + + def sign(self, msg, key): + signer = key.signer( + padding.PKCS1v15(), + self.hash_alg + ) + + signer.update(msg) + return signer.finalize() + + def verify(self, msg, key, sig): + verifier = key.verifier( + sig, + padding.PKCS1v15(), + self.hash_alg + ) + + verifier.update(msg) + + try: + verifier.verify() + return True + except InvalidSignature: + return False + + class ECAlgorithm(Algorithm): + """ + Performs signing and verification operations using ECDSA and the + specified hash function + """ + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if isinstance(key, interfaces.EllipticCurvePrivateKey) or \ + isinstance(key, interfaces.EllipticCurvePublicKey): + return key + + if isinstance(key, string_types): + if isinstance(key, text_type): + key = key.encode('utf-8') + + # Attempt to load key. We don't know if it's + # a Signing Key or a Verifying Key, so we try + # the Verifying Key first. + try: + key = load_pem_public_key(key, backend=default_backend()) + except ValueError: + key = load_pem_private_key(key, password=None, backend=default_backend()) + + else: + raise TypeError('Expecting a PEM-formatted key.') + + return key + + def sign(self, msg, key): + signer = key.signer(ec.ECDSA(self.hash_alg)) + + signer.update(msg) + return signer.finalize() + + def verify(self, msg, key, sig): + verifier = key.verifier(sig, ec.ECDSA(self.hash_alg)) + + verifier.update(msg) + + try: + verifier.verify() + return True + except InvalidSignature: + return False diff --git a/jwt/utils.py b/jwt/utils.py new file mode 100644 index 0000000..e6c1ef3 --- /dev/null +++ b/jwt/utils.py @@ -0,0 +1,14 @@ +import base64 + + +def base64url_decode(input): + rem = len(input) % 4 + + if rem > 0: + input += b'=' * (4 - rem) + + return base64.urlsafe_b64decode(input) + + +def base64url_encode(input): + return base64.urlsafe_b64encode(input).replace(b'=', b'') @@ -1,8 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import os -import sys import re +import sys + from setuptools import setup diff --git a/tests/test_jwt.py b/tests/test_jwt.py index a57ab31..bd9ca06 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -45,6 +45,10 @@ class TestJWT(unittest.TestCase): self.payload = {'iss': 'jeff', 'exp': utc_timestamp() + 15, 'claim': 'insanity'} + def test_register_algorithm_rejects_non_algorithm_obj(self): + with self.assertRaises(TypeError): + jwt.register_algorithm('AAA123', {}) + def test_encode_decode(self): secret = 'secret' jwt_message = jwt.encode(self.payload, secret) @@ -549,35 +553,15 @@ class TestJWT(unittest.TestCase): load_output = jwt.load(jwt_message) jwt.verify_signature(key=pub_rsakey, *load_output) - def test_rsa_related_signing_methods(self): - if has_crypto: - self.assertTrue('RS256' in jwt.signing_methods) - self.assertTrue('RS384' in jwt.signing_methods) - self.assertTrue('RS512' in jwt.signing_methods) - else: - self.assertFalse('RS256' in jwt.signing_methods) - self.assertFalse('RS384' in jwt.signing_methods) - self.assertFalse('RS512' in jwt.signing_methods) - - def test_rsa_related_verify_methods(self): - if has_crypto: - self.assertTrue('RS256' in jwt.verify_methods) - self.assertTrue('RS384' in jwt.verify_methods) - self.assertTrue('RS512' in jwt.verify_methods) - else: - self.assertFalse('RS256' in jwt.verify_methods) - self.assertFalse('RS384' in jwt.verify_methods) - self.assertFalse('RS512' in jwt.verify_methods) - - def test_rsa_related_key_preparation_methods(self): + def test_rsa_related_algorithms(self): if has_crypto: - self.assertTrue('RS256' in jwt.prepare_key_methods) - self.assertTrue('RS384' in jwt.prepare_key_methods) - self.assertTrue('RS512' in jwt.prepare_key_methods) + self.assertTrue('RS256' in jwt._algorithms) + self.assertTrue('RS384' in jwt._algorithms) + self.assertTrue('RS512' in jwt._algorithms) else: - self.assertFalse('RS256' in jwt.prepare_key_methods) - self.assertFalse('RS384' in jwt.prepare_key_methods) - self.assertFalse('RS512' in jwt.prepare_key_methods) + self.assertFalse('RS256' in jwt._algorithms) + self.assertFalse('RS384' in jwt._algorithms) + self.assertFalse('RS512' in jwt._algorithms) @unittest.skipIf(not has_crypto, "Can't run without cryptography library") def test_encode_decode_with_ecdsa_sha256(self): @@ -669,35 +653,15 @@ class TestJWT(unittest.TestCase): load_output = jwt.load(jwt_message) jwt.verify_signature(key=pub_eckey, *load_output) - def test_ecdsa_related_signing_methods(self): - if has_crypto: - self.assertTrue('ES256' in jwt.signing_methods) - self.assertTrue('ES384' in jwt.signing_methods) - self.assertTrue('ES512' in jwt.signing_methods) - else: - self.assertFalse('ES256' in jwt.signing_methods) - self.assertFalse('ES384' in jwt.signing_methods) - self.assertFalse('ES512' in jwt.signing_methods) - - def test_ecdsa_related_verify_methods(self): - if has_crypto: - self.assertTrue('ES256' in jwt.verify_methods) - self.assertTrue('ES384' in jwt.verify_methods) - self.assertTrue('ES512' in jwt.verify_methods) - else: - self.assertFalse('ES256' in jwt.verify_methods) - self.assertFalse('ES384' in jwt.verify_methods) - self.assertFalse('ES512' in jwt.verify_methods) - - def test_ecdsa_related_key_preparation_methods(self): + def test_ecdsa_related_algorithms(self): if has_crypto: - self.assertTrue('ES256' in jwt.prepare_key_methods) - self.assertTrue('ES384' in jwt.prepare_key_methods) - self.assertTrue('ES512' in jwt.prepare_key_methods) + self.assertTrue('ES256' in jwt._algorithms) + self.assertTrue('ES384' in jwt._algorithms) + self.assertTrue('ES512' in jwt._algorithms) else: - self.assertFalse('ES256' in jwt.prepare_key_methods) - self.assertFalse('ES384' in jwt.prepare_key_methods) - self.assertFalse('ES512' in jwt.prepare_key_methods) + self.assertFalse('ES256' in jwt._algorithms) + self.assertFalse('ES384' in jwt._algorithms) + self.assertFalse('ES512' in jwt._algorithms) def test_check_audience(self): payload = { |