summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosé Padilla <jpadilla@webapplicate.com>2015-04-14 10:05:56 -0400
committerJosé Padilla <jpadilla@webapplicate.com>2015-04-14 10:05:56 -0400
commitb39b9a7887c2feab1058fa371f761e1e27f6da1d (patch)
treeca8babf9734bb3deaa90c206655f504279c6c7bb
parent5fd54bec162a25ae9dc0de4c476dc7c51bc1017b (diff)
parent90577f7cef93c4d9e8b6168212f29129fa2b3f71 (diff)
downloadpyjwt-b39b9a7887c2feab1058fa371f761e1e27f6da1d.tar.gz
Merge pull request #135 from mark-adams/minor-updates
Minor refactorings to make things a little cleaner
-rw-r--r--jwt/api.py24
-rw-r--r--jwt/utils.py13
-rw-r--r--tests/test_algorithms.py18
-rw-r--r--tests/test_api.py30
4 files changed, 50 insertions, 35 deletions
diff --git a/jwt/api.py b/jwt/api.py
index cca02d9..273799d 100644
--- a/jwt/api.py
+++ b/jwt/api.py
@@ -13,7 +13,7 @@ from .exceptions import (
InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError,
InvalidIssuerError
)
-from .utils import base64url_decode, base64url_encode
+from .utils import base64url_decode, base64url_encode, merge_dict
class PyJWT(object):
@@ -29,7 +29,7 @@ class PyJWT(object):
if not options:
options = {}
- self.default_options = {
+ default_options = {
'verify_signature': True,
'verify_exp': True,
'verify_nbf': True,
@@ -37,7 +37,7 @@ class PyJWT(object):
'verify_aud': True,
}
- self.options = self._merge_options(self.default_options, options)
+ self.options = merge_dict(default_options, options)
def register_algorithm(self, alg_id, alg_obj):
"""
@@ -85,6 +85,7 @@ class PyJWT(object):
# Header
header = {'typ': 'JWT', 'alg': algorithm}
+
if headers:
header.update(headers)
@@ -128,7 +129,7 @@ class PyJWT(object):
payload, signing_input, header, signature = self._load(jwt)
if verify:
- merged_options = self._merge_options(override_options=options)
+ merged_options = merge_dict(self.options, options)
if merged_options.get('verify_signature'):
self._verify_signature(payload, signing_input, header, signature,
key, algorithms)
@@ -251,21 +252,6 @@ class PyJWT(object):
if payload.get('iss') != issuer:
raise InvalidIssuerError('Invalid issuer')
- def _merge_options(self, default_options=None, override_options=None):
- if not default_options:
- default_options = {}
-
- if not override_options:
- override_options = {}
-
- try:
- merged_options = self.default_options.copy()
- merged_options.update(override_options)
- except (AttributeError, ValueError) as e:
- raise TypeError('options must be a dictionary: %s' % e)
-
- return merged_options
-
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
diff --git a/jwt/utils.py b/jwt/utils.py
index e6c1ef3..bb7b3a3 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -12,3 +12,16 @@ def base64url_decode(input):
def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')
+
+
+def merge_dict(original, updates):
+ if not updates:
+ return original
+
+ try:
+ merged_options = original.copy()
+ merged_options.update(updates)
+ except (AttributeError, ValueError) as e:
+ raise TypeError('original and updates must be a dictionary: %s' % e)
+
+ return merged_options
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 8eb0eb3..5eb24bc 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -113,9 +113,9 @@ class TestAlgorithms(unittest.TestCase):
def test_rsa_verify_should_return_false_if_signature_invalid(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ message = ensure_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ sig = base64.b64decode(ensure_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -123,21 +123,21 @@ class TestAlgorithms(unittest.TestCase):
'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA'
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
- jwt_sig += ensure_bytes('123') # Signature is now invalid
+ sig += ensure_bytes('123') # Signature is now invalid
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
- jwt_pub_key = algo.prepare_key(keyfile.read())
+ pub_key = algo.prepare_key(keyfile.read())
- result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
+ result = algo.verify(message, pub_key, sig)
self.assertFalse(result)
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
def test_rsa_verify_should_return_true_if_signature_valid(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ message = ensure_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ sig = base64.b64decode(ensure_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -146,9 +146,9 @@ class TestAlgorithms(unittest.TestCase):
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
- jwt_pub_key = algo.prepare_key(keyfile.read())
+ pub_key = algo.prepare_key(keyfile.read())
- result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
+ result = algo.verify(message, pub_key, sig)
self.assertTrue(result)
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
diff --git a/tests/test_api.py b/tests/test_api.py
index 13aa982..a45107c 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -14,8 +14,9 @@ from jwt.exceptions import (
InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError,
InvalidIssuerError
)
+from jwt.utils import base64url_decode
-from .compat import text_type, unittest
+from .compat import string_types, text_type, unittest
from .utils import ensure_bytes
try:
@@ -80,19 +81,16 @@ class TestAPI(unittest.TestCase):
self.assertNotIn('none', self.jwt.get_algorithms())
self.assertIn('HS256', self.jwt.get_algorithms())
- def test_default_options(self):
- self.assertEqual(self.jwt.default_options, self.jwt.options)
-
def test_override_options(self):
self.jwt = PyJWT(options={'verify_exp': False, 'verify_nbf': False})
- expected_options = self.jwt.default_options
+ expected_options = self.jwt.options
expected_options['verify_exp'] = False
expected_options['verify_nbf'] = False
self.assertEqual(expected_options, self.jwt.options)
- def test_non_default_options_persist(self):
+ def test_non_object_options_persist(self):
self.jwt = PyJWT(options={'verify_iat': False, 'foobar': False})
- expected_options = self.jwt.default_options
+ expected_options = self.jwt.options
expected_options['verify_iat'] = False
expected_options['foobar'] = False
self.assertEqual(expected_options, self.jwt.options)
@@ -880,6 +878,24 @@ class TestAPI(unittest.TestCase):
payload = self.jwt.decode(token, 'secret')
self.assertEqual(payload, {'some_decimal': 'it worked'})
+ def test_encode_headers_parameter_adds_headers(self):
+ headers = {'testheader': True}
+ token = self.jwt.encode({'msg': 'hello world'}, 'secret', headers=headers)
+
+ if not isinstance(token, string_types):
+ token = token.decode()
+
+ header = token[0:token.index('.')].encode()
+ header = base64url_decode(header)
+
+ if not isinstance(header, text_type):
+ header = header.decode()
+
+ header_obj = json.loads(header)
+
+ self.assertIn('testheader', header_obj)
+ self.assertEqual(header_obj['testheader'], headers['testheader'])
+
if __name__ == '__main__':
unittest.main()