summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Adams <mark@markadams.me>2017-03-14 08:04:03 -0500
committerMark Adams <mark@markadams.me>2017-03-14 10:43:58 -0500
commitd04339de54fca44c09ae9105627ab4ae7e0abdb1 (patch)
treedcfdd94edb8d89b2222949ead86694a9e1720be8
parent1710c1524c69c39dfece7a24b87179be5eeff217 (diff)
downloadpyjwt-fix-key-errors.tar.gz
Refactor error handling in Algorithm.prepare_key() methodsfix-key-errors
Our error handling in Algorithm.prepare_key() was previously weird and kind of inconsistent. This change makes a number of improvements: * Refactors RSA and ECDSA prepare_key() methods to reduce nesting and make the code simpler to understand * All calls to Algorithm.prepare_key() return InvalidKeyError (or a subclass) or a valid key instance. * Created a new InvalidAsymmetricKeyError class that is used to provide a standard message when an invalid RSA or ECDSA key is used.
-rw-r--r--CHANGELOG.md1
-rw-r--r--jwt/algorithms.py78
-rw-r--r--jwt/contrib/algorithms/py_ecdsa.py3
-rw-r--r--jwt/contrib/algorithms/pycrypto.py3
-rw-r--r--jwt/exceptions.py10
-rw-r--r--tests/contrib/test_algorithms.py5
-rw-r--r--tests/test_algorithms.py8
7 files changed, 68 insertions, 40 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 060876c..6b599b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
-------------------------------------------------------------------------
### Changed
- Add support for ECDSA public keys in RFC 4253 (OpenSSH) format [#244][244]
+- All Algorithm.prepare_key() calls now return either a valid key value or raise InvalidKeyError
- Renamed commandline script `jwt` to `jwt-cli` to avoid issues with the script clobbering the `jwt` module in some circumstances.
- Better error messages when using an algorithm that requires the cryptography package, but it isn't available [#230][230]
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index f6d990a..7e423b7 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -4,7 +4,8 @@ import json
from .compat import constant_time_compare, string_types
-from .exceptions import InvalidKeyError
+from .exceptions import (InvalidAsymmetricKeyError, InvalidJwkError,
+ InvalidKeyError)
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
@@ -137,6 +138,9 @@ class HMACAlgorithm(Algorithm):
self.hash_alg = hash_alg
def prepare_key(self, key):
+ if not isinstance(key, string_types):
+ raise InvalidKeyError("HMAC secret key must be a string type.")
+
key = force_bytes(key)
invalid_strings = [
@@ -164,7 +168,7 @@ class HMACAlgorithm(Algorithm):
obj = json.loads(jwk)
if obj.get('kty') != 'oct':
- raise InvalidKeyError('Not an HMAC key')
+ raise InvalidKeyError('Invalid key: Not an HMAC key')
return base64url_decode(obj['k'])
@@ -194,20 +198,28 @@ if has_crypto:
isinstance(key, RSAPublicKey):
return key
- if isinstance(key, string_types):
- key = force_bytes(key)
+ if not isinstance(key, string_types):
+ raise InvalidAsymmetricKeyError
+
+ key = force_bytes(key)
+ if key.startswith(b'ssh-rsa'):
try:
- if key.startswith(b'ssh-rsa'):
- key = load_ssh_public_key(key, backend=default_backend())
- else:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ return load_ssh_public_key(key, backend=default_backend())
except ValueError:
- key = load_pem_public_key(key, backend=default_backend())
- else:
- raise TypeError('Expecting a PEM-formatted key.')
+ raise InvalidAsymmetricKeyError
+
+ try:
+ return load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ pass
+
+ try:
+ return load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ pass
- return key
+ raise InvalidAsymmetricKeyError
@staticmethod
def to_jwk(key_obj):
@@ -241,7 +253,7 @@ if has_crypto:
'e': force_unicode(to_base64url_uint(numbers.e))
}
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance.')
return json.dumps(obj)
@@ -250,22 +262,22 @@ if has_crypto:
try:
obj = json.loads(jwk)
except ValueError:
- raise InvalidKeyError('Key is not valid JSON')
+ raise InvalidJwkError('Key is not valid JSON')
if obj.get('kty') != 'RSA':
- raise InvalidKeyError('Not an RSA key')
+ raise InvalidJwkError('Not an RSA key')
if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
- raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
+ raise InvalidJwkError('Unsupported RSA private key: > 2 primes not supported')
other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)
if any_props_found and not all(props_found):
- raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
+ raise InvalidJwkError('RSA key must include all parameters if any are present besides d')
public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
@@ -306,7 +318,7 @@ if has_crypto:
return numbers.public_key(default_backend())
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError('Not a valid JWK public or private key')
def sign(self, msg, key):
signer = key.signer(
@@ -349,24 +361,28 @@ if has_crypto:
isinstance(key, EllipticCurvePublicKey):
return key
- if isinstance(key, string_types):
- key = force_bytes(key)
+ if not isinstance(key, string_types):
+ raise InvalidAsymmetricKeyError
+
+ key = force_bytes(key)
- # Attempt to load key. We don't know if it's
- # a Signing Key or a Verifying Key, so we try
- # the Verifying Key first.
+ if key.startswith(b'ecdsa-sha2-'):
try:
- if key.startswith(b'ecdsa-sha2-'):
- key = load_ssh_public_key(key, backend=default_backend())
- else:
- key = load_pem_public_key(key, backend=default_backend())
+ return load_ssh_public_key(key, backend=default_backend())
except ValueError:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ raise InvalidAsymmetricKeyError
- else:
- raise TypeError('Expecting a PEM-formatted key.')
+ try:
+ return load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ pass
+
+ try:
+ return load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ pass
- return key
+ raise InvalidAsymmetricKeyError
def sign(self, msg, key):
signer = key.signer(ec.ECDSA(self.hash_alg()))
diff --git a/jwt/contrib/algorithms/py_ecdsa.py b/jwt/contrib/algorithms/py_ecdsa.py
index bf0dea5..23016fc 100644
--- a/jwt/contrib/algorithms/py_ecdsa.py
+++ b/jwt/contrib/algorithms/py_ecdsa.py
@@ -7,6 +7,7 @@ import ecdsa
from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type
+from jwt.exceptions import InvalidAsymmetricKeyError
class ECAlgorithm(Algorithm):
@@ -44,7 +45,7 @@ class ECAlgorithm(Algorithm):
key = ecdsa.SigningKey.from_pem(key)
else:
- raise TypeError('Expecting a PEM-formatted key.')
+ raise InvalidAsymmetricKeyError('Expecting a PEM-formatted key.')
return key
diff --git a/jwt/contrib/algorithms/pycrypto.py b/jwt/contrib/algorithms/pycrypto.py
index e6afaa5..fe477db 100644
--- a/jwt/contrib/algorithms/pycrypto.py
+++ b/jwt/contrib/algorithms/pycrypto.py
@@ -7,6 +7,7 @@ from Crypto.Signature import PKCS1_v1_5
from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type
+from jwt.exceptions import InvalidAsymmetricKeyError
class RSAAlgorithm(Algorithm):
@@ -36,7 +37,7 @@ class RSAAlgorithm(Algorithm):
key = RSA.importKey(key)
else:
- raise TypeError('Expecting a PEM- or RSA-formatted key.')
+ raise InvalidAsymmetricKeyError
return key
diff --git a/jwt/exceptions.py b/jwt/exceptions.py
index 31177a0..48a6e16 100644
--- a/jwt/exceptions.py
+++ b/jwt/exceptions.py
@@ -26,10 +26,18 @@ class ImmatureSignatureError(InvalidTokenError):
pass
-class InvalidKeyError(Exception):
+class InvalidKeyError(ValueError):
pass
+class InvalidAsymmetricKeyError(InvalidKeyError):
+ message = 'Invalid key: Keys must be in PEM or RFC 4253 format.'
+
+
+class InvalidJwkError(InvalidKeyError):
+ message = 'Invalid key: Keys must be in JWK format.'
+
+
class InvalidAlgorithmError(InvalidTokenError):
pass
diff --git a/tests/contrib/test_algorithms.py b/tests/contrib/test_algorithms.py
index 6d5ca75..b774ad1 100644
--- a/tests/contrib/test_algorithms.py
+++ b/tests/contrib/test_algorithms.py
@@ -1,5 +1,6 @@
import base64
+from jwt.exceptions import InvalidAsymmetricKeyError
from jwt.utils import force_bytes, force_unicode
import pytest
@@ -36,7 +37,7 @@ class TestPycryptoAlgorithms:
def test_rsa_should_reject_non_string_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- with pytest.raises(TypeError):
+ with pytest.raises(InvalidAsymmetricKeyError):
algo.prepare_key(None)
def test_rsa_sign_should_generate_correct_signature_value(self):
@@ -117,7 +118,7 @@ class TestEcdsaAlgorithms:
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- with pytest.raises(TypeError):
+ with pytest.raises(InvalidAsymmetricKeyError):
algo.prepare_key(None)
def test_ec_should_accept_unicode_key(self):
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 11d8cd0..69501ee 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -58,11 +58,11 @@ class TestAlgorithms:
def test_hmac_should_reject_nonstring_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
- with pytest.raises(TypeError) as context:
+ with pytest.raises(InvalidKeyError) as context:
algo.prepare_key(object())
exception = context.value
- assert str(exception) == 'Expected a string value'
+ assert str(exception) == 'HMAC secret key must be a string type.'
def test_hmac_should_accept_unicode_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
@@ -144,7 +144,7 @@ class TestAlgorithms:
def test_rsa_should_reject_non_string_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- with pytest.raises(TypeError):
+ with pytest.raises(InvalidKeyError):
algo.prepare_key(None)
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
@@ -358,7 +358,7 @@ class TestAlgorithms:
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- with pytest.raises(TypeError):
+ with pytest.raises(InvalidKeyError):
algo.prepare_key(None)
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')