summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py78
1 files changed, 47 insertions, 31 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index f6d990a..7e423b7 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -4,7 +4,8 @@ import json
from .compat import constant_time_compare, string_types
-from .exceptions import InvalidKeyError
+from .exceptions import (InvalidAsymmetricKeyError, InvalidJwkError,
+ InvalidKeyError)
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
@@ -137,6 +138,9 @@ class HMACAlgorithm(Algorithm):
self.hash_alg = hash_alg
def prepare_key(self, key):
+ if not isinstance(key, string_types):
+ raise InvalidKeyError("HMAC secret key must be a string type.")
+
key = force_bytes(key)
invalid_strings = [
@@ -164,7 +168,7 @@ class HMACAlgorithm(Algorithm):
obj = json.loads(jwk)
if obj.get('kty') != 'oct':
- raise InvalidKeyError('Not an HMAC key')
+ raise InvalidKeyError('Invalid key: Not an HMAC key')
return base64url_decode(obj['k'])
@@ -194,20 +198,28 @@ if has_crypto:
isinstance(key, RSAPublicKey):
return key
- if isinstance(key, string_types):
- key = force_bytes(key)
+ if not isinstance(key, string_types):
+ raise InvalidAsymmetricKeyError
+
+ key = force_bytes(key)
+ if key.startswith(b'ssh-rsa'):
try:
- if key.startswith(b'ssh-rsa'):
- key = load_ssh_public_key(key, backend=default_backend())
- else:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ return load_ssh_public_key(key, backend=default_backend())
except ValueError:
- key = load_pem_public_key(key, backend=default_backend())
- else:
- raise TypeError('Expecting a PEM-formatted key.')
+ raise InvalidAsymmetricKeyError
+
+ try:
+ return load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ pass
+
+ try:
+ return load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ pass
- return key
+ raise InvalidAsymmetricKeyError
@staticmethod
def to_jwk(key_obj):
@@ -241,7 +253,7 @@ if has_crypto:
'e': force_unicode(to_base64url_uint(numbers.e))
}
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance.')
return json.dumps(obj)
@@ -250,22 +262,22 @@ if has_crypto:
try:
obj = json.loads(jwk)
except ValueError:
- raise InvalidKeyError('Key is not valid JSON')
+ raise InvalidJwkError('Key is not valid JSON')
if obj.get('kty') != 'RSA':
- raise InvalidKeyError('Not an RSA key')
+ raise InvalidJwkError('Not an RSA key')
if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
- raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
+ raise InvalidJwkError('Unsupported RSA private key: > 2 primes not supported')
other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)
if any_props_found and not all(props_found):
- raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
+ raise InvalidJwkError('RSA key must include all parameters if any are present besides d')
public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
@@ -306,7 +318,7 @@ if has_crypto:
return numbers.public_key(default_backend())
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError('Not a valid JWK public or private key')
def sign(self, msg, key):
signer = key.signer(
@@ -349,24 +361,28 @@ if has_crypto:
isinstance(key, EllipticCurvePublicKey):
return key
- if isinstance(key, string_types):
- key = force_bytes(key)
+ if not isinstance(key, string_types):
+ raise InvalidAsymmetricKeyError
+
+ key = force_bytes(key)
- # Attempt to load key. We don't know if it's
- # a Signing Key or a Verifying Key, so we try
- # the Verifying Key first.
+ if key.startswith(b'ecdsa-sha2-'):
try:
- if key.startswith(b'ecdsa-sha2-'):
- key = load_ssh_public_key(key, backend=default_backend())
- else:
- key = load_pem_public_key(key, backend=default_backend())
+ return load_ssh_public_key(key, backend=default_backend())
except ValueError:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ raise InvalidAsymmetricKeyError
- else:
- raise TypeError('Expecting a PEM-formatted key.')
+ try:
+ return load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ pass
+
+ try:
+ return load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ pass
- return key
+ raise InvalidAsymmetricKeyError
def sign(self, msg, key):
signer = key.signer(ec.ECDSA(self.hash_alg()))