diff options
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 78 |
1 files changed, 47 insertions, 31 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index f6d990a..7e423b7 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -4,7 +4,8 @@ import json from .compat import constant_time_compare, string_types -from .exceptions import InvalidKeyError +from .exceptions import (InvalidAsymmetricKeyError, InvalidJwkError, + InvalidKeyError) from .utils import ( base64url_decode, base64url_encode, der_to_raw_signature, force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature, @@ -137,6 +138,9 @@ class HMACAlgorithm(Algorithm): self.hash_alg = hash_alg def prepare_key(self, key): + if not isinstance(key, string_types): + raise InvalidKeyError("HMAC secret key must be a string type.") + key = force_bytes(key) invalid_strings = [ @@ -164,7 +168,7 @@ class HMACAlgorithm(Algorithm): obj = json.loads(jwk) if obj.get('kty') != 'oct': - raise InvalidKeyError('Not an HMAC key') + raise InvalidKeyError('Invalid key: Not an HMAC key') return base64url_decode(obj['k']) @@ -194,20 +198,28 @@ if has_crypto: isinstance(key, RSAPublicKey): return key - if isinstance(key, string_types): - key = force_bytes(key) + if not isinstance(key, string_types): + raise InvalidAsymmetricKeyError + + key = force_bytes(key) + if key.startswith(b'ssh-rsa'): try: - if key.startswith(b'ssh-rsa'): - key = load_ssh_public_key(key, backend=default_backend()) - else: - key = load_pem_private_key(key, password=None, backend=default_backend()) + return load_ssh_public_key(key, backend=default_backend()) except ValueError: - key = load_pem_public_key(key, backend=default_backend()) - else: - raise TypeError('Expecting a PEM-formatted key.') + raise InvalidAsymmetricKeyError + + try: + return load_pem_private_key(key, password=None, backend=default_backend()) + except ValueError: + pass + + try: + return load_pem_public_key(key, backend=default_backend()) + except ValueError: + pass - return key + raise InvalidAsymmetricKeyError @staticmethod def to_jwk(key_obj): @@ -241,7 +253,7 @@ if has_crypto: 'e': force_unicode(to_base64url_uint(numbers.e)) } else: - raise InvalidKeyError('Not a public or private key') + raise InvalidKeyError('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance.') return json.dumps(obj) @@ -250,22 +262,22 @@ if has_crypto: try: obj = json.loads(jwk) except ValueError: - raise InvalidKeyError('Key is not valid JSON') + raise InvalidJwkError('Key is not valid JSON') if obj.get('kty') != 'RSA': - raise InvalidKeyError('Not an RSA key') + raise InvalidJwkError('Not an RSA key') if 'd' in obj and 'e' in obj and 'n' in obj: # Private key if 'oth' in obj: - raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + raise InvalidJwkError('Unsupported RSA private key: > 2 primes not supported') other_props = ['p', 'q', 'dp', 'dq', 'qi'] props_found = [prop in obj for prop in other_props] any_props_found = any(props_found) if any_props_found and not all(props_found): - raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + raise InvalidJwkError('RSA key must include all parameters if any are present besides d') public_numbers = RSAPublicNumbers( from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) @@ -306,7 +318,7 @@ if has_crypto: return numbers.public_key(default_backend()) else: - raise InvalidKeyError('Not a public or private key') + raise InvalidKeyError('Not a valid JWK public or private key') def sign(self, msg, key): signer = key.signer( @@ -349,24 +361,28 @@ if has_crypto: isinstance(key, EllipticCurvePublicKey): return key - if isinstance(key, string_types): - key = force_bytes(key) + if not isinstance(key, string_types): + raise InvalidAsymmetricKeyError + + key = force_bytes(key) - # Attempt to load key. We don't know if it's - # a Signing Key or a Verifying Key, so we try - # the Verifying Key first. + if key.startswith(b'ecdsa-sha2-'): try: - if key.startswith(b'ecdsa-sha2-'): - key = load_ssh_public_key(key, backend=default_backend()) - else: - key = load_pem_public_key(key, backend=default_backend()) + return load_ssh_public_key(key, backend=default_backend()) except ValueError: - key = load_pem_private_key(key, password=None, backend=default_backend()) + raise InvalidAsymmetricKeyError - else: - raise TypeError('Expecting a PEM-formatted key.') + try: + return load_pem_public_key(key, backend=default_backend()) + except ValueError: + pass + + try: + return load_pem_private_key(key, password=None, backend=default_backend()) + except ValueError: + pass - return key + raise InvalidAsymmetricKeyError def sign(self, msg, key): signer = key.signer(ec.ECDSA(self.hash_alg())) |