diff options
author | Landon GB <landogbland@gmail.com> | 2016-11-30 09:33:46 -0700 |
---|---|---|
committer | Landon GB <landogbland@gmail.com> | 2016-11-30 09:33:46 -0700 |
commit | 0550fa347ed29e2ffcd8f26a3adcf0ddcae6fdc3 (patch) | |
tree | c3b7b81d13b8f6b7bdaefe0788ba904ff99b3454 | |
parent | 8520c8f0e0c9ea13df785e8143ee193f088c008a (diff) | |
download | pyjwt-0550fa347ed29e2ffcd8f26a3adcf0ddcae6fdc3.tar.gz |
Changes per code review
-rw-r--r-- | jwt/algorithms.py | 53 | ||||
-rw-r--r-- | jwt/api_jws.py | 10 | ||||
-rw-r--r-- | tests/test_api_jws.py | 2 |
3 files changed, 21 insertions, 44 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 77a910d..8360ed9 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -31,34 +31,8 @@ try: except ImportError: has_crypto = False - -def _get_crypto_algorithms(): - crypto_algorithms = { - 'RS256': None, - 'RS384': None, - 'RS512': None, - 'ES256': None, - 'ES384': None, - 'ES521': None, - 'ES512': None, - 'PS256': None, - 'PS384': None, - 'PS512': None - } - - if has_crypto: - crypto_algorithms['RS256'] = RSAAlgorithm(RSAAlgorithm.SHA256) - crypto_algorithms['RS384'] = RSAAlgorithm(RSAAlgorithm.SHA384) - crypto_algorithms['RS512'] = RSAAlgorithm(RSAAlgorithm.SHA512) - crypto_algorithms['ES256'] = ECAlgorithm(ECAlgorithm.SHA256) - crypto_algorithms['ES384'] = ECAlgorithm(ECAlgorithm.SHA384) - crypto_algorithms['ES521'] = ECAlgorithm(ECAlgorithm.SHA512) - crypto_algorithms['ES512'] = ECAlgorithm(ECAlgorithm.SHA512) - crypto_algorithms['PS256'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256) - crypto_algorithms['PS384'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384) - crypto_algorithms['PS512'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512) - - return crypto_algorithms +requires_cryptography = {'RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'ES521', + 'ES512', 'PS256', 'PS384', 'PS512'} def get_default_algorithms(): @@ -73,21 +47,22 @@ def get_default_algorithms(): } if has_crypto: - crypto_algorithms = _get_crypto_algorithms() - default_algorithms.update(crypto_algorithms) + default_algorithms.update({ + 'RS256': RSAAlgorithm(RSAAlgorithm.SHA256), + 'RS384': RSAAlgorithm(RSAAlgorithm.SHA384), + 'RS512': RSAAlgorithm(RSAAlgorithm.SHA512), + 'ES256': ECAlgorithm(ECAlgorithm.SHA256), + 'ES384': ECAlgorithm(ECAlgorithm.SHA384), + 'ES521': ECAlgorithm(ECAlgorithm.SHA512), + 'ES512': ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix + 'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + 'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + 'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512) + }) return default_algorithms -def get_crypto_algorithms(): - """ - Returns a set of algorithm names that require the cryptography package to - be installed in order to use. - """ - crypto_algorithms = _get_crypto_algorithms().keys() - return set(crypto_algorithms) - - class Algorithm(object): """ The interface for an algorithm used to sign and verify tokens. diff --git a/jwt/api_jws.py b/jwt/api_jws.py index df54e3f..66b6a67 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -5,7 +5,7 @@ import warnings from collections import Mapping from .algorithms import ( - Algorithm, get_crypto_algorithms, get_default_algorithms, has_crypto # NOQA + Algorithm, get_default_algorithms, has_crypto, requires_cryptography # NOQA ) from .compat import binary_type, string_types, text_type from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError @@ -103,9 +103,11 @@ class PyJWS(object): signature = alg_obj.sign(signing_input, key) except KeyError: - if not has_crypto and algorithm in get_crypto_algorithms(): - raise NotImplementedError('"cryptography" package must be ' - 'installed to use this algorithm') + if not has_crypto and algorithm in requires_cryptography: + raise NotImplementedError( + "Algorithm '%s' could not be found. Do you have cryptography " + "installed?" % algorithm + ) else: raise NotImplementedError('Algorithm not supported') diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index d0876a4..053dd11 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -303,7 +303,7 @@ class TestJWS: with pytest.raises(NotImplementedError): jws.encode(payload, 'secret', algorithm='HS1024') - @pytest.mark.skipif(has_crypto, reason='Tests better errors if crypography not installed') + @pytest.mark.skipif(has_crypto, reason='Scenario requires cryptography to not be installed') def test_missing_crypto_library_better_error_messages(self, jws, payload): with pytest.raises(NotImplementedError) as excinfo: jws.encode(payload, 'secret', algorithm='RS256') |