diff options
author | Michael Davis <mike.davis@webfilings.com> | 2015-04-06 20:14:37 -0500 |
---|---|---|
committer | Michael Davis <mike.davis@webfilings.com> | 2015-04-06 22:58:13 -0500 |
commit | 3ada770fc3247d61bc2d2a50e35c1cb18c7abb2b (patch) | |
tree | 9763f1651ca5bfb8f2293c6b3ab8a99bb62b0d6b | |
parent | a2601ad46433a99c8777a74abeaf4dfd70630d17 (diff) | |
download | pyjwt-3ada770fc3247d61bc2d2a50e35c1cb18c7abb2b.tar.gz |
Add flexible and complete verification options
Attempts to fix #127
-rw-r--r-- | jwt/api.py | 49 | ||||
-rw-r--r-- | tests/test_api.py | 69 |
2 files changed, 105 insertions, 13 deletions
@@ -16,7 +16,7 @@ from .utils import base64url_decode, base64url_encode class PyJWT(object): - def __init__(self, algorithms=None): + def __init__(self, algorithms=None, options=None): self._algorithms = get_default_algorithms() self._valid_algs = set(algorithms) if algorithms is not None else set(self._algorithms) @@ -25,6 +25,22 @@ class PyJWT(object): if key not in self._valid_algs: del self._algorithms[key] + if not options: + options = {} + + self.default_options = { + 'verify_signature': True, + 'verify_exp': True, + 'verify_nbf': True, + 'verify_iat': True, + 'verify_aud': True, + } + + try: + self.options = {k: options[k] if k in options else v for k, v in self.default_options.items()} + except (ValueError, TypeError) as e: + raise TypeError('options must be a dictionary: %s' % e) + def register_algorithm(self, alg_id, alg_obj): """ Registers a new Algorithm for use when creating and verifying tokens. @@ -110,14 +126,16 @@ class PyJWT(object): return b'.'.join(segments) - def decode(self, jwt, key='', verify=True, algorithms=None, **kwargs): + def decode(self, jwt, key='', verify=True, algorithms=None, options=None, **kwargs): payload, signing_input, header, signature = self._load(jwt) if verify: - self._verify_signature(payload, signing_input, header, signature, - key, algorithms) + merged_options = self._merge_options(override_options=options) + if merged_options.get('verify_signature'): + self._verify_signature(payload, signing_input, header, signature, + key, algorithms) - self._validate_claims(payload, **kwargs) + self._validate_claims(payload, options=merged_options, **kwargs) return payload @@ -177,8 +195,8 @@ class PyJWT(object): except KeyError: raise InvalidAlgorithmError('Algorithm not supported') - def _validate_claims(self, payload, verify_expiration=True, leeway=0, - audience=None, issuer=None): + def _validate_claims(self, payload, audience=None, issuer=None, leeway=0, + options=None, **kwargs): if isinstance(leeway, timedelta): leeway = timedelta_total_seconds(leeway) @@ -187,7 +205,7 @@ class PyJWT(object): now = timegm(datetime.utcnow().utctimetuple()) - if 'iat' in payload: + if 'iat' in payload and options.get('verify_iat'): try: iat = int(payload['iat']) except ValueError: @@ -196,7 +214,7 @@ class PyJWT(object): if iat > (now + leeway): raise InvalidIssuedAtError('Issued At claim (iat) cannot be in the future.') - if 'nbf' in payload and verify_expiration: + if 'nbf' in payload and options.get('verify_nbf'): try: nbf = int(payload['nbf']) except ValueError: @@ -205,7 +223,7 @@ class PyJWT(object): if nbf > (now + leeway): raise ImmatureSignatureError('The token is not yet valid (nbf)') - if 'exp' in payload and verify_expiration: + if 'exp' in payload and options.get('verify_exp'): try: exp = int(payload['exp']) except ValueError: @@ -214,7 +232,7 @@ class PyJWT(object): if exp < (now - leeway): raise ExpiredSignatureError('Signature has expired') - if 'aud' in payload: + if 'aud' in payload and options.get('verify_aud'): audience_claims = payload['aud'] if isinstance(audience_claims, string_types): audience_claims = [audience_claims] @@ -233,6 +251,15 @@ class PyJWT(object): if payload.get('iss') != issuer: raise InvalidIssuerError('Invalid issuer') + def _merge_options(self, override_options=None): + if not override_options: + override_options = {} + try: + options = {k: override_options[k] if k in override_options else v for k, v in self.options.items()} + except (ValueError, TypeError) as e: + raise TypeError('options must be a dictionary: %s' % e) + return options + _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode diff --git a/tests/test_api.py b/tests/test_api.py index f1734b4..c5a715b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -71,6 +71,26 @@ class TestAPI(unittest.TestCase): self.assertNotIn('none', self.jwt.get_algorithms()) self.assertIn('HS256', self.jwt.get_algorithms()) + def test_default_options(self): + self.assertEqual(self.jwt.default_options, self.jwt.options) + + def test_override_options(self): + self.jwt = PyJWT(options={'verify_exp': False, 'verify_nbf': False}) + expected_options = self.jwt.default_options + expected_options['verify_exp'] = False + expected_options['verify_nbf'] = False + self.assertEqual(expected_options, self.jwt.options) + + def test_non_existant_options_dont_exist(self): + self.jwt = PyJWT(options={'verify_iat': False, 'foobar': False}) + expected_options = self.jwt.default_options + expected_options['verify_iat'] = False + self.assertEqual(expected_options, self.jwt.options) + self.assertNotIn('foobar', self.jwt.options) + + def test_options_must_be_dict(self): + self.assertRaises(TypeError, PyJWT, options=object()) + def test_encode_decode(self): secret = 'secret' jwt_message = self.jwt.encode(self.payload, secret) @@ -467,14 +487,14 @@ class TestAPI(unittest.TestCase): secret = 'secret' jwt_message = self.jwt.encode(self.payload, secret) - self.jwt.decode(jwt_message, secret, verify_expiration=False) + self.jwt.decode(jwt_message, secret, options={'verify_exp': False}) def test_decode_skip_notbefore_verification(self): self.payload['nbf'] = time.time() + 10 secret = 'secret' jwt_message = self.jwt.encode(self.payload, secret) - self.jwt.decode(jwt_message, secret, verify_expiration=False) + self.jwt.decode(jwt_message, secret, options={'verify_nbf': False}) def test_decode_with_expiration_with_leeway(self): self.payload['exp'] = utc_timestamp() - 2 @@ -765,6 +785,51 @@ class TestAPI(unittest.TestCase): with self.assertRaises(InvalidIssuerError): self.jwt.decode(token, 'secret', issuer=issuer) + def test_skip_check_audience(self): + payload = { + 'some': 'payload', + 'aud': 'urn:me', + } + token = self.jwt.encode(payload, 'secret') + self.jwt.decode(token, 'secret', options={'verify_aud': False}) + + def test_skip_check_exp(self): + payload = { + 'some': 'payload', + 'exp': datetime.utcnow() - timedelta(days=1) + } + token = self.jwt.encode(payload, 'secret') + self.jwt.decode(token, 'secret', options={'verify_exp': False}) + + def test_skip_check_signature(self): + token = ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + ".eyJzb21lIjoicGF5bG9hZCJ9" + ".4twFt5NiznN84AWoo1d7KO1T_yoc0Z6XOpOVswacPZA") + self.jwt.decode(token, 'secret', options={'verify_signature': False}) + + def test_skip_check_iat(self): + payload = { + 'some': 'payload', + 'iat': datetime.utcnow() + timedelta(days=1) + } + token = self.jwt.encode(payload, 'secret') + self.jwt.decode(token, 'secret', options={'verify_iat': False}) + + def test_skip_check_nbf(self): + payload = { + 'some': 'payload', + 'nbf': datetime.utcnow() + timedelta(days=1) + } + token = self.jwt.encode(payload, 'secret') + self.jwt.decode(token, 'secret', options={'verify_nbf': False}) + + def test_decode_options_must_be_dict(self): + payload = { + 'some': 'payload', + } + token = self.jwt.encode(payload, 'secret') + self.assertRaises(TypeError, self.jwt.decode, token, 'secret', options=object()) + def test_custom_json_encoder(self): class CustomJSONEncoder(json.JSONEncoder): |