diff options
author | Mark Adams <mark@markadams.me> | 2015-01-06 08:00:28 -0600 |
---|---|---|
committer | Mark Adams <mark@markadams.me> | 2015-01-18 10:26:25 -0600 |
commit | 1da2d4a52d55d64f77d7c6d6b52cdba555dc0e0b (patch) | |
tree | 4364dabb3fe0a6f3608e67f901e0dc7239b2f967 /jwt/algorithms.py | |
parent | 0afba10cf16834e154a59280de089c30de3d9a61 (diff) | |
download | pyjwt-1da2d4a52d55d64f77d7c6d6b52cdba555dc0e0b.tar.gz |
Fixes #70. Refactored all HMAC, RSA, and EC code into seperate classes in the algorithms module
Added register_algorithm to add new algorithms.
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py new file mode 100644 index 0000000..2f6f113 --- /dev/null +++ b/jwt/algorithms.py @@ -0,0 +1,200 @@ +import hashlib +import hmac +import sys + +from jwt import register_algorithm + +if sys.version_info >= (3, 0, 0): + unicode = str + basestring = str + +try: + from cryptography.hazmat.primitives import interfaces, hashes + from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, load_pem_public_key, load_ssh_public_key + ) + from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.backends import default_backend + from cryptography.exceptions import InvalidSignature + + has_crypto = True +except ImportError: + has_crypto = False + + +def _register_default_algorithms(): + register_algorithm('none', NoneAlgorithm()) + register_algorithm('HS256', HMACAlgorithm(hashlib.sha256)) + register_algorithm('HS384', HMACAlgorithm(hashlib.sha384)) + register_algorithm('HS512', HMACAlgorithm(hashlib.sha512)) + + if has_crypto: + register_algorithm('RS256', RSAAlgorithm(hashes.SHA256())) + register_algorithm('RS384', RSAAlgorithm(hashes.SHA384())) + register_algorithm('RS512', RSAAlgorithm(hashes.SHA512())) + + register_algorithm('ES256', ECAlgorithm(hashes.SHA256())) + register_algorithm('ES384', ECAlgorithm(hashes.SHA384())) + register_algorithm('ES512', ECAlgorithm(hashes.SHA512())) + + +class Algorithm(object): + def prepare_key(self, key): + pass + + def sign(self, msg, key): + pass + + def verify(self, msg, key, sig): + pass + + +class NoneAlgorithm(Algorithm): + def prepare_key(self, key): + return None + + def sign(self, msg, key): + return b'' + + def verify(self, msg, key): + return True + + +class HMACAlgorithm(Algorithm): + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if not isinstance(key, basestring) and not isinstance(key, bytes): + raise TypeError('Expecting a string- or bytes-formatted key.') + + if isinstance(key, unicode): + key = key.encode('utf-8') + + return key + + def sign(self, msg, key): + return hmac.new(key, msg, self.hash_alg).digest() + + def verify(self, msg, key, sig): + return self._constant_time_compare(sig, self.sign(msg, key)) + + try: + _constant_time_compare = staticmethod(hmac.compare_digest) + except AttributeError: + # Fallback for Python < 2.7.7 and Python < 3.3 + @staticmethod + def constant_time_compare(val1, val2): + """ + Returns True if the two strings are equal, False otherwise. + + The time taken is independent of the number of characters that match. + """ + if len(val1) != len(val2): + return False + + result = 0 + + if sys.version_info >= (3, 0, 0): + # Bytes are numbers + for x, y in zip(val1, val2): + result |= x ^ y + else: + for x, y in zip(val1, val2): + result |= ord(x) ^ ord(y) + + return result == 0 + +if has_crypto: + + class RSAAlgorithm(Algorithm): + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if isinstance(key, interfaces.RSAPrivateKey) or \ + isinstance(key, interfaces.RSAPublicKey): + return key + + if isinstance(key, basestring): + if isinstance(key, unicode): + key = key.encode('utf-8') + + try: + if key.startswith(b'ssh-rsa'): + key = load_ssh_public_key(key, backend=default_backend()) + else: + key = load_pem_private_key(key, password=None, backend=default_backend()) + except ValueError: + key = load_pem_public_key(key, backend=default_backend()) + else: + raise TypeError('Expecting a PEM-formatted key.') + + return key + + def sign(self, msg, key): + signer = key.signer( + padding.PKCS1v15(), + self.hash_alg + ) + + signer.update(msg) + return signer.finalize() + + def verify(self, msg, key, sig): + verifier = key.verifier( + sig, + padding.PKCS1v15(), + self.hash_alg + ) + + verifier.update(msg) + + try: + verifier.verify() + return True + except InvalidSignature: + return False + + class ECAlgorithm(Algorithm): + def __init__(self, hash_alg): + self.hash_alg = hash_alg + + def prepare_key(self, key): + if isinstance(key, interfaces.EllipticCurvePrivateKey) or \ + isinstance(key, interfaces.EllipticCurvePublicKey): + return key + + if isinstance(key, basestring): + if isinstance(key, unicode): + key = key.encode('utf-8') + + # Attempt to load key. We don't know if it's + # a Signing Key or a Verifying Key, so we try + # the Verifying Key first. + try: + key = load_pem_public_key(key, backend=default_backend()) + except ValueError: + key = load_pem_private_key(key, password=None, backend=default_backend()) + + else: + raise TypeError('Expecting a PEM-formatted key.') + + return key + + def sign(self, msg, key): + signer = key.signer(ec.ECDSA(self.hash_alg)) + + signer.update(msg) + return signer.finalize() + + def verify(self, msg, key, sig): + verifier = key.verifier(sig, ec.ECDSA(self.hash_alg)) + + verifier.update(msg) + + try: + verifier.verify() + return True + except InvalidSignature: + return False |