diff options
author | José Padilla <jpadilla@webapplicate.com> | 2019-10-21 22:38:34 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-10-21 22:38:34 -0400 |
commit | 11ac89474b1179925c76450fcc4b3d2042c45f19 (patch) | |
tree | dda6b15326cab750f5e52ee57e3125bb5d0a7eee /jwt/algorithms.py | |
parent | ae080f472c913ad94456fd9e10b05ec2d038b7cc (diff) | |
download | pyjwt-11ac89474b1179925c76450fcc4b3d2042c45f19.tar.gz |
DX Tweaks (#450)
* Setup pre-commit hooks
* Run initial `tox -e lint`
* Fix package name
* Fix .travis.yml
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 251 |
1 files changed, 154 insertions, 97 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 1343688..293a470 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -2,26 +2,39 @@ import hashlib import hmac import json - from .compat import constant_time_compare, string_types from .exceptions import InvalidKeyError from .utils import ( - base64url_decode, base64url_encode, der_to_raw_signature, - force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature, - to_base64url_uint + base64url_decode, + base64url_encode, + der_to_raw_signature, + force_bytes, + force_unicode, + from_base64url_uint, + raw_to_der_signature, + to_base64url_uint, ) try: from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key, load_ssh_public_key + load_pem_private_key, + load_pem_public_key, + load_ssh_public_key, ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers, - rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp + RSAPrivateKey, + RSAPublicKey, + RSAPrivateNumbers, + RSAPublicNumbers, + rsa_recover_prime_factors, + rsa_crt_dmp1, + rsa_crt_dmq1, + rsa_crt_iqmp, ) from cryptography.hazmat.primitives.asymmetric.ec import ( - EllipticCurvePrivateKey, EllipticCurvePublicKey + EllipticCurvePrivateKey, + EllipticCurvePublicKey, ) from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.backends import default_backend @@ -31,8 +44,20 @@ try: except ImportError: has_crypto = False -requires_cryptography = set(['RS256', 'RS384', 'RS512', 'ES256', 'ES384', - 'ES521', 'ES512', 'PS256', 'PS384', 'PS512']) +requires_cryptography = set( + [ + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES521", + "ES512", + "PS256", + "PS384", + "PS512", + ] +) def get_default_algorithms(): @@ -40,25 +65,29 @@ def get_default_algorithms(): Returns the algorithms that are implemented by the library. """ default_algorithms = { - 'none': NoneAlgorithm(), - 'HS256': HMACAlgorithm(HMACAlgorithm.SHA256), - 'HS384': HMACAlgorithm(HMACAlgorithm.SHA384), - 'HS512': HMACAlgorithm(HMACAlgorithm.SHA512) + "none": NoneAlgorithm(), + "HS256": HMACAlgorithm(HMACAlgorithm.SHA256), + "HS384": HMACAlgorithm(HMACAlgorithm.SHA384), + "HS512": HMACAlgorithm(HMACAlgorithm.SHA512), } if has_crypto: - default_algorithms.update({ - 'RS256': RSAAlgorithm(RSAAlgorithm.SHA256), - 'RS384': RSAAlgorithm(RSAAlgorithm.SHA384), - 'RS512': RSAAlgorithm(RSAAlgorithm.SHA512), - 'ES256': ECAlgorithm(ECAlgorithm.SHA256), - 'ES384': ECAlgorithm(ECAlgorithm.SHA384), - 'ES521': ECAlgorithm(ECAlgorithm.SHA512), - 'ES512': ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix - 'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), - 'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), - 'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512) - }) + default_algorithms.update( + { + "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), + "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), + "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), + "ES256": ECAlgorithm(ECAlgorithm.SHA256), + "ES384": ECAlgorithm(ECAlgorithm.SHA384), + "ES521": ECAlgorithm(ECAlgorithm.SHA512), + "ES512": ECAlgorithm( + ECAlgorithm.SHA512 + ), # Backward compat for #219 fix + "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), + } + ) return default_algorithms @@ -67,6 +96,7 @@ class Algorithm(object): """ The interface for an algorithm used to sign and verify tokens. """ + def prepare_key(self, key): """ Performs necessary validation and conversions on the key and returns @@ -108,8 +138,9 @@ class NoneAlgorithm(Algorithm): Placeholder for use when no signing or verification operations are required. """ + def prepare_key(self, key): - if key == '': + if key == "": key = None if key is not None: @@ -118,7 +149,7 @@ class NoneAlgorithm(Algorithm): return key def sign(self, msg, key): - return b'' + return b"" def verify(self, msg, key, sig): return False @@ -129,6 +160,7 @@ class HMACAlgorithm(Algorithm): Performs signing and verification operations using HMAC and the specified hash function. """ + SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 @@ -140,34 +172,37 @@ class HMACAlgorithm(Algorithm): key = force_bytes(key) invalid_strings = [ - b'-----BEGIN PUBLIC KEY-----', - b'-----BEGIN CERTIFICATE-----', - b'-----BEGIN RSA PUBLIC KEY-----', - b'ssh-rsa' + b"-----BEGIN PUBLIC KEY-----", + b"-----BEGIN CERTIFICATE-----", + b"-----BEGIN RSA PUBLIC KEY-----", + b"ssh-rsa", ] if any([string_value in key for string_value in invalid_strings]): raise InvalidKeyError( - 'The specified key is an asymmetric key or x509 certificate and' - ' should not be used as an HMAC secret.') + "The specified key is an asymmetric key or x509 certificate and" + " should not be used as an HMAC secret." + ) return key @staticmethod def to_jwk(key_obj): - return json.dumps({ - 'k': force_unicode(base64url_encode(force_bytes(key_obj))), - 'kty': 'oct' - }) + return json.dumps( + { + "k": force_unicode(base64url_encode(force_bytes(key_obj))), + "kty": "oct", + } + ) @staticmethod def from_jwk(jwk): obj = json.loads(jwk) - if obj.get('kty') != 'oct': - raise InvalidKeyError('Not an HMAC key') + if obj.get("kty") != "oct": + raise InvalidKeyError("Not an HMAC key") - return base64url_decode(obj['k']) + return base64url_decode(obj["k"]) def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() @@ -176,13 +211,14 @@ class HMACAlgorithm(Algorithm): return constant_time_compare(sig, self.sign(msg, key)) -if has_crypto: +if has_crypto: # noqa: C901 class RSAAlgorithm(Algorithm): """ Performs signing and verification operations using RSASSA-PKCS-v1_5 and the specified hash function. """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 @@ -191,22 +227,25 @@ if has_crypto: self.hash_alg = hash_alg def prepare_key(self, key): - if isinstance(key, RSAPrivateKey) or \ - isinstance(key, RSAPublicKey): + if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey): return key if isinstance(key, string_types): key = force_bytes(key) try: - if key.startswith(b'ssh-rsa'): - key = load_ssh_public_key(key, backend=default_backend()) + if key.startswith(b"ssh-rsa"): + key = load_ssh_public_key( + key, backend=default_backend() + ) else: - key = load_pem_private_key(key, password=None, backend=default_backend()) + key = load_pem_private_key( + key, password=None, backend=default_backend() + ) except ValueError: key = load_pem_public_key(key, backend=default_backend()) else: - raise TypeError('Expecting a PEM-formatted key.') + raise TypeError("Expecting a PEM-formatted key.") return key @@ -214,35 +253,39 @@ if has_crypto: def to_jwk(key_obj): obj = None - if getattr(key_obj, 'private_numbers', None): + if getattr(key_obj, "private_numbers", None): # Private key numbers = key_obj.private_numbers() obj = { - 'kty': 'RSA', - 'key_ops': ['sign'], - 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)), - 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)), - 'd': force_unicode(to_base64url_uint(numbers.d)), - 'p': force_unicode(to_base64url_uint(numbers.p)), - 'q': force_unicode(to_base64url_uint(numbers.q)), - 'dp': force_unicode(to_base64url_uint(numbers.dmp1)), - 'dq': force_unicode(to_base64url_uint(numbers.dmq1)), - 'qi': force_unicode(to_base64url_uint(numbers.iqmp)) + "kty": "RSA", + "key_ops": ["sign"], + "n": force_unicode( + to_base64url_uint(numbers.public_numbers.n) + ), + "e": force_unicode( + to_base64url_uint(numbers.public_numbers.e) + ), + "d": force_unicode(to_base64url_uint(numbers.d)), + "p": force_unicode(to_base64url_uint(numbers.p)), + "q": force_unicode(to_base64url_uint(numbers.q)), + "dp": force_unicode(to_base64url_uint(numbers.dmp1)), + "dq": force_unicode(to_base64url_uint(numbers.dmq1)), + "qi": force_unicode(to_base64url_uint(numbers.iqmp)), } - elif getattr(key_obj, 'verify', None): + elif getattr(key_obj, "verify", None): # Public key numbers = key_obj.public_numbers() obj = { - 'kty': 'RSA', - 'key_ops': ['verify'], - 'n': force_unicode(to_base64url_uint(numbers.n)), - 'e': force_unicode(to_base64url_uint(numbers.e)) + "kty": "RSA", + "key_ops": ["verify"], + "n": force_unicode(to_base64url_uint(numbers.n)), + "e": force_unicode(to_base64url_uint(numbers.e)), } else: - raise InvalidKeyError('Not a public or private key') + raise InvalidKeyError("Not a public or private key") return json.dumps(obj) @@ -251,39 +294,44 @@ if has_crypto: try: obj = json.loads(jwk) except ValueError: - raise InvalidKeyError('Key is not valid JSON') + raise InvalidKeyError("Key is not valid JSON") - if obj.get('kty') != 'RSA': - raise InvalidKeyError('Not an RSA key') + if obj.get("kty") != "RSA": + raise InvalidKeyError("Not an RSA key") - if 'd' in obj and 'e' in obj and 'n' in obj: + if "d" in obj and "e" in obj and "n" in obj: # Private key - if 'oth' in obj: - raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + if "oth" in obj: + raise InvalidKeyError( + "Unsupported RSA private key: > 2 primes not supported" + ) - other_props = ['p', 'q', 'dp', 'dq', 'qi'] + other_props = ["p", "q", "dp", "dq", "qi"] props_found = [prop in obj for prop in other_props] any_props_found = any(props_found) if any_props_found and not all(props_found): - raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + raise InvalidKeyError( + "RSA key must include all parameters if any are present besides d" + ) public_numbers = RSAPublicNumbers( - from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), ) if any_props_found: numbers = RSAPrivateNumbers( - d=from_base64url_uint(obj['d']), - p=from_base64url_uint(obj['p']), - q=from_base64url_uint(obj['q']), - dmp1=from_base64url_uint(obj['dp']), - dmq1=from_base64url_uint(obj['dq']), - iqmp=from_base64url_uint(obj['qi']), - public_numbers=public_numbers + d=from_base64url_uint(obj["d"]), + p=from_base64url_uint(obj["p"]), + q=from_base64url_uint(obj["q"]), + dmp1=from_base64url_uint(obj["dp"]), + dmq1=from_base64url_uint(obj["dq"]), + iqmp=from_base64url_uint(obj["qi"]), + public_numbers=public_numbers, ) else: - d = from_base64url_uint(obj['d']) + d = from_base64url_uint(obj["d"]) p, q = rsa_recover_prime_factors( public_numbers.n, d, public_numbers.e ) @@ -295,19 +343,20 @@ if has_crypto: dmp1=rsa_crt_dmp1(d, p), dmq1=rsa_crt_dmq1(d, q), iqmp=rsa_crt_iqmp(p, q), - public_numbers=public_numbers + public_numbers=public_numbers, ) return numbers.private_key(default_backend()) - elif 'n' in obj and 'e' in obj: + elif "n" in obj and "e" in obj: # Public key numbers = RSAPublicNumbers( - from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), ) return numbers.public_key(default_backend()) else: - raise InvalidKeyError('Not a public or private key') + raise InvalidKeyError("Not a public or private key") def sign(self, msg, key): return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) @@ -324,6 +373,7 @@ if has_crypto: Performs signing and verification operations using ECDSA and the specified hash function """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 @@ -332,8 +382,9 @@ if has_crypto: self.hash_alg = hash_alg def prepare_key(self, key): - if isinstance(key, EllipticCurvePrivateKey) or \ - isinstance(key, EllipticCurvePublicKey): + if isinstance(key, EllipticCurvePrivateKey) or isinstance( + key, EllipticCurvePublicKey + ): return key if isinstance(key, string_types): @@ -343,15 +394,21 @@ if has_crypto: # a Signing Key or a Verifying Key, so we try # the Verifying Key first. try: - if key.startswith(b'ecdsa-sha2-'): - key = load_ssh_public_key(key, backend=default_backend()) + if key.startswith(b"ecdsa-sha2-"): + key = load_ssh_public_key( + key, backend=default_backend() + ) else: - key = load_pem_public_key(key, backend=default_backend()) + key = load_pem_public_key( + key, backend=default_backend() + ) except ValueError: - key = load_pem_private_key(key, password=None, backend=default_backend()) + key = load_pem_private_key( + key, password=None, backend=default_backend() + ) else: - raise TypeError('Expecting a PEM-formatted key.') + raise TypeError("Expecting a PEM-formatted key.") return key @@ -382,9 +439,9 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size + salt_length=self.hash_alg.digest_size, ), - self.hash_alg() + self.hash_alg(), ) def verify(self, msg, key, sig): @@ -394,9 +451,9 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size + salt_length=self.hash_alg.digest_size, ), - self.hash_alg() + self.hash_alg(), ) return True except InvalidSignature: |