diff options
author | José Padilla <jpadilla@webapplicate.com> | 2015-04-14 10:05:56 -0400 |
---|---|---|
committer | José Padilla <jpadilla@webapplicate.com> | 2015-04-14 10:05:56 -0400 |
commit | b39b9a7887c2feab1058fa371f761e1e27f6da1d (patch) | |
tree | ca8babf9734bb3deaa90c206655f504279c6c7bb | |
parent | 5fd54bec162a25ae9dc0de4c476dc7c51bc1017b (diff) | |
parent | 90577f7cef93c4d9e8b6168212f29129fa2b3f71 (diff) | |
download | pyjwt-b39b9a7887c2feab1058fa371f761e1e27f6da1d.tar.gz |
Merge pull request #135 from mark-adams/minor-updates
Minor refactorings to make things a little cleaner
-rw-r--r-- | jwt/api.py | 24 | ||||
-rw-r--r-- | jwt/utils.py | 13 | ||||
-rw-r--r-- | tests/test_algorithms.py | 18 | ||||
-rw-r--r-- | tests/test_api.py | 30 |
4 files changed, 50 insertions, 35 deletions
@@ -13,7 +13,7 @@ from .exceptions import ( InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError ) -from .utils import base64url_decode, base64url_encode +from .utils import base64url_decode, base64url_encode, merge_dict class PyJWT(object): @@ -29,7 +29,7 @@ class PyJWT(object): if not options: options = {} - self.default_options = { + default_options = { 'verify_signature': True, 'verify_exp': True, 'verify_nbf': True, @@ -37,7 +37,7 @@ class PyJWT(object): 'verify_aud': True, } - self.options = self._merge_options(self.default_options, options) + self.options = merge_dict(default_options, options) def register_algorithm(self, alg_id, alg_obj): """ @@ -85,6 +85,7 @@ class PyJWT(object): # Header header = {'typ': 'JWT', 'alg': algorithm} + if headers: header.update(headers) @@ -128,7 +129,7 @@ class PyJWT(object): payload, signing_input, header, signature = self._load(jwt) if verify: - merged_options = self._merge_options(override_options=options) + merged_options = merge_dict(self.options, options) if merged_options.get('verify_signature'): self._verify_signature(payload, signing_input, header, signature, key, algorithms) @@ -251,21 +252,6 @@ class PyJWT(object): if payload.get('iss') != issuer: raise InvalidIssuerError('Invalid issuer') - def _merge_options(self, default_options=None, override_options=None): - if not default_options: - default_options = {} - - if not override_options: - override_options = {} - - try: - merged_options = self.default_options.copy() - merged_options.update(override_options) - except (AttributeError, ValueError) as e: - raise TypeError('options must be a dictionary: %s' % e) - - return merged_options - _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode diff --git a/jwt/utils.py b/jwt/utils.py index e6c1ef3..bb7b3a3 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -12,3 +12,16 @@ def base64url_decode(input): def base64url_encode(input): return base64.urlsafe_b64encode(input).replace(b'=', b'') + + +def merge_dict(original, updates): + if not updates: + return original + + try: + merged_options = original.copy() + merged_options.update(updates) + except (AttributeError, ValueError) as e: + raise TypeError('original and updates must be a dictionary: %s' % e) + + return merged_options diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 8eb0eb3..5eb24bc 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -113,9 +113,9 @@ class TestAlgorithms(unittest.TestCase): def test_rsa_verify_should_return_false_if_signature_invalid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + message = ensure_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + sig = base64.b64decode(ensure_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -123,21 +123,21 @@ class TestAlgorithms(unittest.TestCase): 'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA' 'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA==')) - jwt_sig += ensure_bytes('123') # Signature is now invalid + sig += ensure_bytes('123') # Signature is now invalid with open(key_path('testkey_rsa.pub'), 'r') as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + pub_key = algo.prepare_key(keyfile.read()) - result = algo.verify(jwt_message, jwt_pub_key, jwt_sig) + result = algo.verify(message, pub_key, sig) self.assertFalse(result) @unittest.skipIf(not has_crypto, 'Not supported without cryptography library') def test_rsa_verify_should_return_true_if_signature_valid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + message = ensure_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + sig = base64.b64decode(ensure_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -146,9 +146,9 @@ class TestAlgorithms(unittest.TestCase): 'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA==')) with open(key_path('testkey_rsa.pub'), 'r') as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + pub_key = algo.prepare_key(keyfile.read()) - result = algo.verify(jwt_message, jwt_pub_key, jwt_sig) + result = algo.verify(message, pub_key, sig) self.assertTrue(result) @unittest.skipIf(not has_crypto, 'Not supported without cryptography library') diff --git a/tests/test_api.py b/tests/test_api.py index 13aa982..a45107c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -14,8 +14,9 @@ from jwt.exceptions import ( InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError ) +from jwt.utils import base64url_decode -from .compat import text_type, unittest +from .compat import string_types, text_type, unittest from .utils import ensure_bytes try: @@ -80,19 +81,16 @@ class TestAPI(unittest.TestCase): self.assertNotIn('none', self.jwt.get_algorithms()) self.assertIn('HS256', self.jwt.get_algorithms()) - def test_default_options(self): - self.assertEqual(self.jwt.default_options, self.jwt.options) - def test_override_options(self): self.jwt = PyJWT(options={'verify_exp': False, 'verify_nbf': False}) - expected_options = self.jwt.default_options + expected_options = self.jwt.options expected_options['verify_exp'] = False expected_options['verify_nbf'] = False self.assertEqual(expected_options, self.jwt.options) - def test_non_default_options_persist(self): + def test_non_object_options_persist(self): self.jwt = PyJWT(options={'verify_iat': False, 'foobar': False}) - expected_options = self.jwt.default_options + expected_options = self.jwt.options expected_options['verify_iat'] = False expected_options['foobar'] = False self.assertEqual(expected_options, self.jwt.options) @@ -880,6 +878,24 @@ class TestAPI(unittest.TestCase): payload = self.jwt.decode(token, 'secret') self.assertEqual(payload, {'some_decimal': 'it worked'}) + def test_encode_headers_parameter_adds_headers(self): + headers = {'testheader': True} + token = self.jwt.encode({'msg': 'hello world'}, 'secret', headers=headers) + + if not isinstance(token, string_types): + token = token.decode() + + header = token[0:token.index('.')].encode() + header = base64url_decode(header) + + if not isinstance(header, text_type): + header = header.decode() + + header_obj = json.loads(header) + + self.assertIn('testheader', header_obj) + self.assertEqual(header_obj['testheader'], headers['testheader']) + if __name__ == '__main__': unittest.main() |