diff options
author | Mauricio Aizaga <mauricioaizaga@gmail.com> | 2015-10-23 10:35:23 -0500 |
---|---|---|
committer | Mauricio Aizaga <mauricioaizaga@gmail.com> | 2015-10-23 10:35:23 -0500 |
commit | 3a50a425822fa3663643a98c6f6692f5acd8194a (patch) | |
tree | c9cbba520ff6f1ae7201182ac8b6e5aa3f73af91 | |
parent | 9f39f38a34569541691bf1b068fc8b3717f13d64 (diff) | |
download | pyjwt-3a50a425822fa3663643a98c6f6692f5acd8194a.tar.gz |
binary_type verification added to make the code more future-proof
-rw-r--r-- | jwt/api_jws.py | 9 | ||||
-rw-r--r-- | jwt/compat.py | 2 | ||||
-rw-r--r-- | tests/test_api_jws.py | 10 |
3 files changed, 13 insertions, 8 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py index dcb2647..177f5ff 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -5,7 +5,7 @@ import warnings from collections import Mapping from .algorithms import Algorithm, get_default_algorithms # NOQA -from .compat import string_types, text_type +from .compat import binary_type, string_types, text_type from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError from .utils import base64url_decode, base64url_encode, merge_dict @@ -135,12 +135,13 @@ class PyJWS(object): if isinstance(jwt, text_type): jwt = jwt.encode('utf-8') + if not issubclass(type(jwt), binary_type): + raise DecodeError("Invalid token type. Token must be a {0}".format( + binary_type)) + try: signing_input, crypto_segment = jwt.rsplit(b'.', 1) header_segment, payload_segment = signing_input.split(b'.', 1) - except AttributeError: - raise DecodeError('Invalid payload type. It should be a {0} not {1}'.format( - text_type, type(jwt))) except ValueError: raise DecodeError('Not enough segments') diff --git a/jwt/compat.py b/jwt/compat.py index 11d423b..8f6bfed 100644 --- a/jwt/compat.py +++ b/jwt/compat.py @@ -13,9 +13,11 @@ PY3 = sys.version_info[0] == 3 if PY3: string_types = str, text_type = str + binary_type = bytes else: string_types = basestring, text_type = unicode + binary_type = str def timedelta_total_seconds(delta): diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index c08d3e4..c56ec4b 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -122,23 +122,25 @@ class TestJWS: exception = context.value assert str(exception) == 'Not enough segments' - def test_decode_invalid_payload_type_is_none(self, jws): + def test_decode_invalid_token_type_is_none(self, jws): example_jws = None example_secret = 'secret' with pytest.raises(DecodeError) as context: jws.decode(example_jws, example_secret) - assert 'Invalid payload type' in str(context.value) + exception = context.value + assert 'Invalid token type' in str(exception) - def test_decode_invalid_payload_type_is_int(self, jws): + def test_decode_invalid_token_type_is_int(self, jws): example_jws = 123 example_secret = 'secret' with pytest.raises(DecodeError) as context: jws.decode(example_jws, example_secret) - assert 'Invalid payload type' in str(context.value) + exception = context.value + assert 'Invalid token type' in str(exception) def test_decode_with_non_mapping_header_throws_exception(self, jws): secret = 'secret' |