summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Adams <mark@markadams.me>2016-05-05 23:58:32 -0500
committerMark Adams <mark@markadams.me>2016-08-28 18:08:55 -0500
commit42b011410b28ea3e3920f40641a57c8893e2e04e (patch)
treede27ac86a51799d196e4dcfd04c5cbd069e7d22b
parentb35d522135044ba10ac41e7db5b95348cb4c4707 (diff)
downloadpyjwt-add-jwk-for-hmac-rsa.tar.gz
Add JWK support for HMAC and RSA keysadd-jwk-for-hmac-rsa
- JWKs for RSA and HMAC can be encoded / decoded using the .to_jwk() and .from_jwk() methods on their respective jwt.algorithms instances - Replaced tests.utils ensure_unicode and ensure_bytes with jwt.utils versions
-rw-r--r--jwt/algorithms.py158
-rw-r--r--jwt/api_jws.py14
-rw-r--r--jwt/compat.py28
-rw-r--r--jwt/utils.py46
-rw-r--r--tests/contrib/test_algorithms.py36
-rw-r--r--tests/keys/__init__.py30
-rw-r--r--tests/test_algorithms.py270
-rw-r--r--tests/test_api_jws.py37
-rw-r--r--tests/test_compat.py9
-rw-r--r--tests/test_utils.py40
-rw-r--r--tests/utils.py16
11 files changed, 554 insertions, 130 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 51e8f16..3144f23 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,9 +1,15 @@
import hashlib
import hmac
+import json
-from .compat import binary_type, constant_time_compare, is_string_type
+
+from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
-from .utils import der_to_raw_signature, raw_to_der_signature
+from .utils import (
+ base64url_decode, base64url_encode, der_to_raw_signature,
+ force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
+ to_base64url_uint
+)
try:
from cryptography.hazmat.primitives import hashes
@@ -11,7 +17,8 @@ try:
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey, RSAPublicKey
+ RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
+ rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
@@ -77,6 +84,20 @@ class Algorithm(object):
"""
raise NotImplementedError
+ @staticmethod
+ def to_jwk(key_obj):
+ """
+ Serializes a given RSA key into a JWK
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def from_jwk(jwk):
+ """
+ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
+ """
+ raise NotImplementedError
+
class NoneAlgorithm(Algorithm):
"""
@@ -112,11 +133,7 @@ class HMACAlgorithm(Algorithm):
self.hash_alg = hash_alg
def prepare_key(self, key):
- if not is_string_type(key):
- raise TypeError('Expecting a string- or bytes-formatted key.')
-
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ key = force_bytes(key)
invalid_strings = [
b'-----BEGIN PUBLIC KEY-----',
@@ -131,6 +148,22 @@ class HMACAlgorithm(Algorithm):
return key
+ @staticmethod
+ def to_jwk(key_obj):
+ return json.dumps({
+ 'k': force_unicode(base64url_encode(force_bytes(key_obj))),
+ 'kty': 'oct'
+ })
+
+ @staticmethod
+ def from_jwk(jwk):
+ obj = json.loads(jwk)
+
+ if obj.get('kty') != 'oct':
+ raise InvalidKeyError('Not an HMAC key')
+
+ return base64url_decode(obj['k'])
+
def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()
@@ -156,9 +189,8 @@ if has_crypto:
isinstance(key, RSAPublicKey):
return key
- if is_string_type(key):
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ if isinstance(key, string_types):
+ key = force_bytes(key)
try:
if key.startswith(b'ssh-rsa'):
@@ -172,6 +204,105 @@ if has_crypto:
return key
+ @staticmethod
+ def to_jwk(key_obj):
+ obj = None
+
+ if getattr(key_obj, 'private_numbers', None):
+ # Private key
+ numbers = key_obj.private_numbers()
+
+ obj = {
+ 'kty': 'RSA',
+ 'key_ops': ['sign'],
+ 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
+ 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
+ 'd': force_unicode(to_base64url_uint(numbers.d)),
+ 'p': force_unicode(to_base64url_uint(numbers.p)),
+ 'q': force_unicode(to_base64url_uint(numbers.q)),
+ 'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
+ 'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
+ 'qi': force_unicode(to_base64url_uint(numbers.iqmp))
+ }
+
+ elif getattr(key_obj, 'verifier', None):
+ # Public key
+ numbers = key_obj.public_numbers()
+
+ obj = {
+ 'kty': 'RSA',
+ 'key_ops': ['verify'],
+ 'n': force_unicode(to_base64url_uint(numbers.n)),
+ 'e': force_unicode(to_base64url_uint(numbers.e))
+ }
+ else:
+ raise InvalidKeyError('Not a public or private key')
+
+ return json.dumps(obj)
+
+ @staticmethod
+ def from_jwk(jwk):
+ try:
+ obj = json.loads(jwk)
+ except ValueError:
+ raise InvalidKeyError('Key is not valid JSON')
+
+ if obj.get('kty') != 'RSA':
+ raise InvalidKeyError('Not an RSA key')
+
+ if 'd' in obj and 'e' in obj and 'n' in obj:
+ # Private key
+ if 'oth' in obj:
+ raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
+
+ other_props = ['p', 'q', 'dp', 'dq', 'qi']
+ props_found = [prop in obj for prop in other_props]
+ any_props_found = any(props_found)
+
+ if any_props_found and not all(props_found):
+ raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
+
+ public_numbers = RSAPublicNumbers(
+ from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ )
+
+ if any_props_found:
+ numbers = RSAPrivateNumbers(
+ d=from_base64url_uint(obj['d']),
+ p=from_base64url_uint(obj['p']),
+ q=from_base64url_uint(obj['q']),
+ dmp1=from_base64url_uint(obj['dp']),
+ dmq1=from_base64url_uint(obj['dq']),
+ iqmp=from_base64url_uint(obj['qi']),
+ public_numbers=public_numbers
+ )
+ else:
+ d = from_base64url_uint(obj['d'])
+ p, q = rsa_recover_prime_factors(
+ public_numbers.n, d, public_numbers.e
+ )
+
+ numbers = RSAPrivateNumbers(
+ d=d,
+ p=p,
+ q=q,
+ dmp1=rsa_crt_dmp1(d, p),
+ dmq1=rsa_crt_dmq1(d, q),
+ iqmp=rsa_crt_iqmp(p, q),
+ public_numbers=public_numbers
+ )
+
+ return numbers.private_key(default_backend())
+ elif 'n' in obj and 'e' in obj:
+ # Public key
+ numbers = RSAPublicNumbers(
+ from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ )
+
+ return numbers.public_key(default_backend())
+ else:
+ raise InvalidKeyError('Not a public or private key')
+
def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
@@ -213,9 +344,8 @@ if has_crypto:
isinstance(key, EllipticCurvePublicKey):
return key
- if is_string_type(key):
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ if isinstance(key, string_types):
+ key = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 177f5ff..f4e58b7 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -7,7 +7,7 @@ from collections import Mapping
from .algorithms import Algorithm, get_default_algorithms # NOQA
from .compat import binary_type, string_types, text_type
from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError
-from .utils import base64url_decode, base64url_encode, merge_dict
+from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict
class PyJWS(object):
@@ -82,11 +82,13 @@ class PyJWS(object):
self._validate_headers(headers)
header.update(headers)
- json_header = json.dumps(
- header,
- separators=(',', ':'),
- cls=json_encoder
- ).encode('utf-8')
+ json_header = force_bytes(
+ json.dumps(
+ header,
+ separators=(',', ':'),
+ cls=json_encoder
+ )
+ )
segments.append(base64url_encode(json_header))
segments.append(base64url_encode(payload))
diff --git a/jwt/compat.py b/jwt/compat.py
index dafd0c7..b928c7d 100644
--- a/jwt/compat.py
+++ b/jwt/compat.py
@@ -3,8 +3,9 @@ The `compat` module provides support for backwards compatibility with older
versions of python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
-import sys
import hmac
+import struct
+import sys
PY3 = sys.version_info[0] == 3
@@ -20,10 +21,6 @@ else:
string_types = (text_type, binary_type)
-def is_string_type(val):
- return any([isinstance(val, typ) for typ in string_types])
-
-
def timedelta_total_seconds(delta):
try:
delta.total_seconds
@@ -56,3 +53,24 @@ except AttributeError:
result |= ord(x) ^ ord(y)
return result == 0
+
+# Use int.to_bytes if it exists (Python 3)
+if getattr(int, 'to_bytes', None):
+ def bytes_from_int(val):
+ remaining = val
+ byte_length = 0
+
+ while remaining != 0:
+ remaining = remaining >> 8
+ byte_length += 1
+
+ return val.to_bytes(byte_length, 'big', signed=False)
+else:
+ def bytes_from_int(val):
+ buf = []
+ while val:
+ val, remainder = divmod(val, 256)
+ buf.append(remainder)
+
+ buf.reverse()
+ return struct.pack('%sB' % len(buf), *buf)
diff --git a/jwt/utils.py b/jwt/utils.py
index 637b892..4f0a496 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -1,5 +1,8 @@
import base64
import binascii
+import struct
+
+from .compat import binary_type, bytes_from_int, text_type
try:
from cryptography.hazmat.primitives.asymmetric.utils import (
@@ -9,7 +12,28 @@ except ImportError:
pass
+def force_unicode(value):
+ if isinstance(value, binary_type):
+ return value.decode('utf-8')
+ elif isinstance(value, text_type):
+ return value
+ else:
+ raise TypeError('Expected a string value')
+
+
+def force_bytes(value):
+ if isinstance(value, text_type):
+ return value.encode('utf-8')
+ elif isinstance(value, binary_type):
+ return value
+ else:
+ raise TypeError('Expected a string value')
+
+
def base64url_decode(input):
+ if isinstance(input, text_type):
+ input = input.encode('ascii')
+
rem = len(input) % 4
if rem > 0:
@@ -22,6 +46,28 @@ def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')
+def to_base64url_uint(val):
+ if val < 0:
+ raise ValueError('Must be a positive integer')
+
+ int_bytes = bytes_from_int(val)
+
+ if len(int_bytes) == 0:
+ int_bytes = b'\x00'
+
+ return base64url_encode(int_bytes)
+
+
+def from_base64url_uint(val):
+ if isinstance(val, text_type):
+ val = val.encode('ascii')
+
+ data = base64url_decode(val)
+
+ buf = struct.unpack('%sB' % len(data), data)
+ return int(''.join(["%02x" % byte for byte in buf]), 16)
+
+
def merge_dict(original, updates):
if not updates:
return original
diff --git a/tests/contrib/test_algorithms.py b/tests/contrib/test_algorithms.py
index 1fbe10d..6d5ca75 100644
--- a/tests/contrib/test_algorithms.py
+++ b/tests/contrib/test_algorithms.py
@@ -1,8 +1,10 @@
import base64
+from jwt.utils import force_bytes, force_unicode
+
import pytest
-from ..utils import ensure_bytes, ensure_unicode, key_path
+from ..utils import key_path
try:
from jwt.contrib.algorithms.pycrypto import RSAAlgorithm
@@ -29,7 +31,7 @@ class TestPycryptoAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path('testkey_rsa'), 'r') as rsa_key:
- algo.prepare_key(ensure_unicode(rsa_key.read()))
+ algo.prepare_key(force_unicode(rsa_key.read()))
def test_rsa_should_reject_non_string_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
@@ -40,9 +42,9 @@ class TestPycryptoAlgorithms:
def test_rsa_sign_should_generate_correct_signature_value(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- expected_sig = base64.b64decode(ensure_bytes(
+ expected_sig = base64.b64decode(force_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -63,9 +65,9 @@ class TestPycryptoAlgorithms:
def test_rsa_verify_should_return_false_if_signature_invalid(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ jwt_sig = base64.b64decode(force_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -73,7 +75,7 @@ class TestPycryptoAlgorithms:
'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA'
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
- jwt_sig += ensure_bytes('123') # Signature is now invalid
+ jwt_sig += force_bytes('123') # Signature is now invalid
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
jwt_pub_key = algo.prepare_key(keyfile.read())
@@ -84,9 +86,9 @@ class TestPycryptoAlgorithms:
def test_rsa_verify_should_return_true_if_signature_valid(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ jwt_sig = base64.b64decode(force_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -122,14 +124,14 @@ class TestEcdsaAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path('testkey_ec'), 'r') as ec_key:
- algo.prepare_key(ensure_unicode(ec_key.read()))
+ algo.prepare_key(force_unicode(ec_key.read()))
def test_ec_sign_should_generate_correct_signature_value(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- expected_sig = base64.b64decode(ensure_bytes(
+ expected_sig = base64.b64decode(force_bytes(
'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M'
'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw'
'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65'))
@@ -147,14 +149,14 @@ class TestEcdsaAlgorithms:
def test_ec_verify_should_return_false_if_signature_invalid(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ jwt_sig = base64.b64decode(force_bytes(
'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M'
'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw'
'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65'))
- jwt_sig += ensure_bytes('123') # Signature is now invalid
+ jwt_sig += force_bytes('123') # Signature is now invalid
with open(key_path('testkey_ec.pub'), 'r') as keyfile:
jwt_pub_key = algo.prepare_key(keyfile.read())
@@ -165,9 +167,9 @@ class TestEcdsaAlgorithms:
def test_ec_verify_should_return_true_if_signature_valid(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ jwt_sig = base64.b64decode(force_bytes(
'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M'
'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw'
'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65'))
diff --git a/tests/keys/__init__.py b/tests/keys/__init__.py
index fad09f5..4fae687 100644
--- a/tests/keys/__init__.py
+++ b/tests/keys/__init__.py
@@ -1,15 +1,15 @@
import json
import os
-from jwt.utils import base64url_decode
+from jwt.utils import base64url_decode, force_bytes
-from tests.utils import ensure_bytes, int_from_bytes
+from tests.utils import int_from_bytes
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
def decode_value(val):
- decoded = base64url_decode(ensure_bytes(val))
+ decoded = base64url_decode(force_bytes(val))
return int_from_bytes(decoded, 'big')
@@ -17,13 +17,12 @@ def load_hmac_key():
with open(os.path.join(BASE_PATH, 'jwk_hmac.json'), 'r') as infile:
keyobj = json.load(infile)
- return base64url_decode(ensure_bytes(keyobj['k']))
+ return base64url_decode(force_bytes(keyobj['k']))
try:
- from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.backends import default_backend
-
+ from jwt.algorithms import RSAAlgorithm
has_crypto = True
except ImportError:
has_crypto = False
@@ -31,26 +30,11 @@ except ImportError:
if has_crypto:
def load_rsa_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile:
- keyobj = json.load(infile)
-
- return rsa.RSAPrivateNumbers(
- p=decode_value(keyobj['p']),
- q=decode_value(keyobj['q']),
- d=decode_value(keyobj['d']),
- dmp1=decode_value(keyobj['dp']),
- dmq1=decode_value(keyobj['dq']),
- iqmp=decode_value(keyobj['qi']),
- public_numbers=load_rsa_pub_key().public_numbers()
- ).private_key(default_backend())
+ return RSAAlgorithm.from_jwk(infile.read())
def load_rsa_pub_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile:
- keyobj = json.load(infile)
-
- return rsa.RSAPublicNumbers(
- n=decode_value(keyobj['n']),
- e=decode_value(keyobj['e'])
- ).public_key(default_backend())
+ return RSAAlgorithm.from_jwk(infile.read())
def load_ec_key():
with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile:
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index e3cf1d0..97fdc22 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,13 +1,14 @@
import base64
+import json
from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm
from jwt.exceptions import InvalidKeyError
-from jwt.utils import base64url_decode
+from jwt.utils import base64url_decode, force_bytes, force_unicode
import pytest
from .keys import load_hmac_key
-from .utils import ensure_bytes, ensure_unicode, key_path
+from .utils import key_path
try:
from jwt.algorithms import RSAAlgorithm, ECAlgorithm, RSAPSSAlgorithm
@@ -36,6 +37,18 @@ class TestAlgorithms:
with pytest.raises(NotImplementedError):
algo.verify('message', 'key', 'signature')
+ def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
+ algo = Algorithm()
+
+ with pytest.raises(NotImplementedError):
+ algo.from_jwk('value')
+
+ def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
+ algo = Algorithm()
+
+ with pytest.raises(NotImplementedError):
+ algo.to_jwk('value')
+
def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()
@@ -49,12 +62,12 @@ class TestAlgorithms:
algo.prepare_key(object())
exception = context.value
- assert str(exception) == 'Expecting a string- or bytes-formatted key.'
+ assert str(exception) == 'Expected a string value'
def test_hmac_should_accept_unicode_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
- algo.prepare_key(ensure_unicode('awesome'))
+ algo.prepare_key(force_unicode('awesome'))
def test_hmac_should_throw_exception_if_key_is_pem_public_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
@@ -84,6 +97,28 @@ class TestAlgorithms:
with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile:
algo.prepare_key(keyfile.read())
+ def test_hmac_jwk_should_parse_and_verify(self):
+ algo = HMACAlgorithm(HMACAlgorithm.SHA256)
+
+ with open(key_path('jwk_hmac.json'), 'r') as keyfile:
+ key = algo.from_jwk(keyfile.read())
+
+ signature = algo.sign(b'Hello World!', key)
+ assert algo.verify(b'Hello World!', key, signature)
+
+ def test_hmac_to_jwk_returns_correct_values(self):
+ algo = HMACAlgorithm(HMACAlgorithm.SHA256)
+ key = algo.to_jwk('secret')
+
+ assert json.loads(key) == {'kty': 'oct', 'k': 'c2VjcmV0'}
+
+ def test_hmac_from_jwk_should_raise_exception_if_not_hmac_key(self):
+ algo = HMACAlgorithm(HMACAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile:
+ with pytest.raises(InvalidKeyError):
+ algo.from_jwk(keyfile.read())
+
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_should_parse_pem_public_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
@@ -103,7 +138,7 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path('testkey_rsa'), 'r') as rsa_key:
- algo.prepare_key(ensure_unicode(rsa_key.read()))
+ algo.prepare_key(force_unicode(rsa_key.read()))
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_should_reject_non_string_key(self):
@@ -116,9 +151,9 @@ class TestAlgorithms:
def test_rsa_verify_should_return_false_if_signature_invalid(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- message = ensure_bytes('Hello World!')
+ message = force_bytes('Hello World!')
- sig = base64.b64decode(ensure_bytes(
+ sig = base64.b64decode(force_bytes(
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -126,7 +161,7 @@ class TestAlgorithms:
'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA'
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
- sig += ensure_bytes('123') # Signature is now invalid
+ sig += force_bytes('123') # Signature is now invalid
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
pub_key = algo.prepare_key(keyfile.read())
@@ -135,6 +170,191 @@ class TestAlgorithms:
assert not result
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile:
+ pub_key = algo.from_jwk(keyfile.read())
+
+ with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
+ priv_key = algo.from_jwk(keyfile.read())
+
+ signature = algo.sign(force_bytes('Hello World!'), priv_key)
+ assert algo.verify(force_bytes('Hello World!'), pub_key, signature)
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_private_key_to_jwk_works_with_from_jwk(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('testkey_rsa'), 'r') as rsa_key:
+ orig_key = algo.prepare_key(force_unicode(rsa_key.read()))
+
+ parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ assert parsed_key.private_numbers() == orig_key.private_numbers()
+ assert parsed_key.private_numbers().public_numbers == orig_key.private_numbers().public_numbers
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_public_key_to_jwk_works_with_from_jwk(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('testkey_rsa.pub'), 'r') as rsa_key:
+ orig_key = algo.prepare_key(force_unicode(rsa_key.read()))
+
+ parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ assert parsed_key.public_numbers() == orig_key.public_numbers()
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_private_key_with_other_primes_is_invalid(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
+ with pytest.raises(InvalidKeyError):
+ keydata = json.loads(keyfile.read())
+ keydata['oth'] = []
+
+ algo.from_jwk(json.dumps(keydata))
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_private_key_with_missing_values_is_invalid(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
+ with pytest.raises(InvalidKeyError):
+ keydata = json.loads(keyfile.read())
+ del keydata['p']
+
+ algo.from_jwk(json.dumps(keydata))
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_private_key_can_recover_prime_factors(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
+ keybytes = keyfile.read()
+ control_key = algo.from_jwk(keybytes).private_numbers()
+
+ keydata = json.loads(keybytes)
+ delete_these = ['p', 'q', 'dp', 'dq', 'qi']
+ for field in delete_these:
+ del keydata[field]
+
+ parsed_key = algo.from_jwk(json.dumps(keydata)).private_numbers()
+
+ assert control_key.d == parsed_key.d
+ assert control_key.p == parsed_key.p
+ assert control_key.q == parsed_key.q
+ assert control_key.dmp1 == parsed_key.dmp1
+ assert control_key.dmq1 == parsed_key.dmq1
+ assert control_key.iqmp == parsed_key.iqmp
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_private_key_with_missing_required_values_is_invalid(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
+ with pytest.raises(InvalidKeyError):
+ keydata = json.loads(keyfile.read())
+ del keydata['p']
+
+ algo.from_jwk(json.dumps(keydata))
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_jwk_raises_exception_if_not_a_valid_key(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ # Invalid JSON
+ with pytest.raises(InvalidKeyError):
+ algo.from_jwk('{not-a-real-key')
+
+ # Missing key parts
+ with pytest.raises(InvalidKeyError):
+ algo.from_jwk('{"kty": "RSA"}')
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_to_jwk_returns_correct_values_for_public_key(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
+ pub_key = algo.prepare_key(keyfile.read())
+
+ key = algo.to_jwk(pub_key)
+
+ expected = {
+ 'e': 'AQAB',
+ 'key_ops': ['verify'],
+ 'kty': 'RSA',
+ 'n': (
+ '1HgzBfJv2cOjQryCwe8NEelriOTNFWKZUivevUrRhlqcmZJdCvuCJRr-xCN-'
+ 'OmO8qwgJJR98feNujxVg-J9Ls3_UOA4HcF9nYH6aqVXELAE8Hk_ALvxi96ms'
+ '1DDuAvQGaYZ-lANxlvxeQFOZSbjkz_9mh8aLeGKwqJLp3p-OhUBQpwvAUAPg'
+ '82-OUtgTW3nSljjeFr14B8qAneGSc_wl0ni--1SRZUXFSovzcqQOkla3W27r'
+ 'rLfrD6LXgj_TsDs4vD1PnIm1zcVenKT7TfYI17bsG_O_Wecwz2Nl19pL7gDo'
+ 'sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ'
+ ),
+ }
+ assert json.loads(key) == expected
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_to_jwk_returns_correct_values_for_private_key(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('testkey_rsa'), 'r') as keyfile:
+ priv_key = algo.prepare_key(keyfile.read())
+
+ key = algo.to_jwk(priv_key)
+
+ expected = {
+ 'key_ops': [u'sign'],
+ 'kty': 'RSA',
+ 'e': 'AQAB',
+ 'n': (
+ '1HgzBfJv2cOjQryCwe8NEelriOTNFWKZUivevUrRhlqcmZJdCvuCJRr-xCN-'
+ 'OmO8qwgJJR98feNujxVg-J9Ls3_UOA4HcF9nYH6aqVXELAE8Hk_ALvxi96ms'
+ '1DDuAvQGaYZ-lANxlvxeQFOZSbjkz_9mh8aLeGKwqJLp3p-OhUBQpwvAUAPg'
+ '82-OUtgTW3nSljjeFr14B8qAneGSc_wl0ni--1SRZUXFSovzcqQOkla3W27r'
+ 'rLfrD6LXgj_TsDs4vD1PnIm1zcVenKT7TfYI17bsG_O_Wecwz2Nl19pL7gDo'
+ 'sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ'
+ ),
+ 'd': ('rfbs8AWdB1RkLJRlC51LukrAvYl5UfU1TE6XRa4o-DTg2-03OXLNEMyVpMr'
+ 'a47weEnu14StypzC8qXL7vxXOyd30SSFTffLfleaTg-qxgMZSDw-Fb_M-pU'
+ 'HMPMEDYG-lgGma4l4fd1yTX2ATtoUo9BVOQgWS1LMZqi0ASEOkUfzlBgL04'
+ 'UoaLhPSuDdLygdlDzgruVPnec0t1uOEObmrcWIkhwU2CGQzeLtuzX6OVgPh'
+ 'k7xcnjbDurTTVpWH0R0gbZ5ukmQ2P-YuCX8T9iWNMGjPNSkb7h02s2Oe9ZR'
+ 'zP007xQ0VF-Z7xyLuxk6ASmoX1S39ujSbk2WF0eXNPRgFwQ'),
+ 'q': ('47hlW2f1ARuWYJf9Dl6MieXjdj2dGx9PL2UH0unVzJYInd56nqXNPrQrc5k'
+ 'ZU65KApC9n9oKUwIxuqwAAbh8oGNEQDqnuTj-powCkdC6bwA8KH1Y-wotpq'
+ '_GSjxkNzjWRm2GArJSzZc6Fb8EuObOrAavKJ285-zMPCEfus1WZG0'),
+ 'p': ('7tr0z929Lp4OHIRJjIKM_rDrWMPtRgnV-51pgWsN6qdpDzns_PgFwrHcoyY'
+ 'sWIO-4yCdVWPxFOgEZ8xXTM_uwOe4VEmdZhw55Tx7axYZtmZYZbO_RIP4CG'
+ 'mlJlOFTiYnxpr-2Cx6kIeQmd-hf7fA3tL018aEzwYMbFMcnAGnEg0'),
+ 'qi': ('djo95mB0LVYikNPa-NgyDwLotLqrueb9IviMmn6zKHCwiOXReqXDX9slB8'
+ 'RA15uv56bmN04O__NyVFcgJ2ef169GZHiRFIgIy0Pl8LYkMhCYKKhyqM7g'
+ 'xN-SqGqDTKDC22j00S7jcvCaa1qadn1qbdfukZ4NXv7E2d_LO0Y2Kkc'),
+ 'dp': ('tgZ2-tJpEdWxu1m1EzeKa644LHVjpTRptk7H0LDc8i6SieADEuWQvkb9df'
+ 'fpY6tDFaQNQr3fQ6dtdAztmsP7l1b_ynwvT1nDZUcqZvl4ruBgDWFmKbjI'
+ 'lOCt0v9jX6MEPP5xqBx9axdkw18BnGtUuHrbzHSlUX-yh_rumpVH1SE'),
+ 'dq': ('xxCIuhD0YlWFbUcwFgGdBWcLIm_WCMGj7SB6aGu1VDTLr4Wu10TFWM0TNu'
+ 'hc9YPker2gpj5qzAmdAzwcfWSSvXpJTYR43jfulBTMoj8-2o3wCM0anclW'
+ 'AuKhin-kc4mh9ssDXRQZwlMymZP0QtaxUDw_nlfVrUCZgO7L1_ZsUTk')
+ }
+ assert json.loads(key) == expected
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_to_jwk_raises_exception_on_invalid_key(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with pytest.raises(InvalidKeyError):
+ algo.to_jwk({'not': 'a valid key'})
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
+ def test_rsa_from_jwk_raises_exception_on_invalid_key(self):
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+
+ with open(key_path('jwk_hmac.json'), 'r') as keyfile:
+ with pytest.raises(InvalidKeyError):
+ algo.from_jwk(keyfile.read())
+
+ @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
@@ -146,7 +366,7 @@ class TestAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path('testkey_ec'), 'r') as ec_key:
- algo.prepare_key(ensure_unicode(ec_key.read()))
+ algo.prepare_key(force_unicode(ec_key.read()))
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_ec_should_accept_pem_private_key_bytes(self):
@@ -159,10 +379,10 @@ class TestAlgorithms:
def test_ec_verify_should_return_false_if_signature_invalid(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- message = ensure_bytes('Hello World!')
+ message = force_bytes('Hello World!')
# Mess up the signature by replacing a known byte
- sig = base64.b64decode(ensure_bytes(
+ sig = base64.b64decode(force_bytes(
'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M'
'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw'
'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65'.replace('r', 's')))
@@ -177,9 +397,9 @@ class TestAlgorithms:
def test_ec_verify_should_return_false_if_signature_wrong_length(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
- message = ensure_bytes('Hello World!')
+ message = force_bytes('Hello World!')
- sig = base64.b64decode(ensure_bytes('AC+m4Jf/xI3guAC6w0w3'))
+ sig = base64.b64decode(force_bytes('AC+m4Jf/xI3guAC6w0w3'))
with open(key_path('testkey_ec.pub'), 'r') as keyfile:
pub_key = algo.prepare_key(keyfile.read())
@@ -191,7 +411,7 @@ class TestAlgorithms:
def test_rsa_pss_sign_then_verify_should_return_true(self):
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
- message = ensure_bytes('Hello World!')
+ message = force_bytes('Hello World!')
with open(key_path('testkey_rsa'), 'r') as keyfile:
priv_key = algo.prepare_key(keyfile.read())
@@ -207,9 +427,9 @@ class TestAlgorithms:
def test_rsa_pss_verify_should_return_false_if_signature_invalid(self):
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
- jwt_message = ensure_bytes('Hello World!')
+ jwt_message = force_bytes('Hello World!')
- jwt_sig = base64.b64decode(ensure_bytes(
+ jwt_sig = base64.b64decode(force_bytes(
'ywKAUGRIDC//6X+tjvZA96yEtMqpOrSppCNfYI7NKyon3P7doud5v65oWNu'
'vQsz0fzPGfF7mQFGo9Cm9Vn0nljm4G6PtqZRbz5fXNQBH9k10gq34AtM02c'
'/cveqACQ8gF3zxWh6qr9jVqIpeMEaEBIkvqG954E0HT9s9ybHShgHX9mlWk'
@@ -217,7 +437,7 @@ class TestAlgorithms:
'daOCWqbpZDuLb1imKpmm8Nsm56kAxijMLZnpCcnPgyb7CqG+B93W9GHglA5'
'drUeR1gRtO7vqbZMsCAQ4bpjXxwbYyjQlEVuMl73UL6sOWg=='))
- jwt_sig += ensure_bytes('123') # Signature is now invalid
+ jwt_sig += force_bytes('123') # Signature is now invalid
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
jwt_pub_key = algo.prepare_key(keyfile.read())
@@ -239,7 +459,7 @@ class TestAlgorithmsRFC7520:
Reference: https://tools.ietf.org/html/rfc7520#section-4.4
"""
- signing_input = ensure_bytes(
+ signing_input = force_bytes(
'eyJhbGciOiJIUzI1NiIsImtpZCI6IjAxOGMwYWU1LTRkOWItNDcxYi1iZmQ2LWVlZ'
'jMxNGJjNzAzNyJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ'
'29pbmcgb3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIG'
@@ -247,7 +467,7 @@ class TestAlgorithmsRFC7520:
'gd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4'
)
- signature = base64url_decode(ensure_bytes(
+ signature = base64url_decode(force_bytes(
's0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0'
))
@@ -265,7 +485,7 @@ class TestAlgorithmsRFC7520:
Reference: https://tools.ietf.org/html/rfc7520#section-4.1
"""
- signing_input = ensure_bytes(
+ signing_input = force_bytes(
'eyJhbGciOiJSUzI1NiIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb'
'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb'
'3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS'
@@ -273,7 +493,7 @@ class TestAlgorithmsRFC7520:
'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4'
)
- signature = base64url_decode(ensure_bytes(
+ signature = base64url_decode(force_bytes(
'MRjdkly7_-oTPTS3AXP41iQIGKa80A0ZmTuV5MEaHoxnW2e5CZ5NlKtainoFmKZop'
'dHM1O2U4mwzJdQx996ivp83xuglII7PNDi84wnB-BDkoBwA78185hX-Es4JIwmDLJ'
'K3lfWRa-XtL0RnltuYv746iYTh_qHRD68BNt1uSNCrUCTJDt5aAE6x8wW1Kt9eRo4'
@@ -296,7 +516,7 @@ class TestAlgorithmsRFC7520:
Reference: https://tools.ietf.org/html/rfc7520#section-4.2
"""
- signing_input = ensure_bytes(
+ signing_input = force_bytes(
'eyJhbGciOiJQUzM4NCIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb'
'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb'
'3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS'
@@ -304,7 +524,7 @@ class TestAlgorithmsRFC7520:
'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4'
)
- signature = base64url_decode(ensure_bytes(
+ signature = base64url_decode(force_bytes(
'cu22eBqkYDKgIlTpzDXGvaFfz6WGoz7fUDcfT0kkOy42miAh2qyBzk1xEsnk2IpN6'
'-tPid6VrklHkqsGqDqHCdP6O8TTB5dDDItllVo6_1OLPpcbUrhiUSMxbbXUvdvWXz'
'g-UD8biiReQFlfz28zGWVsdiNAUf8ZnyPEgVFn442ZdNqiVJRmBqrYRXe8P_ijQ7p'
@@ -327,7 +547,7 @@ class TestAlgorithmsRFC7520:
Reference: https://tools.ietf.org/html/rfc7520#section-4.3
"""
- signing_input = ensure_bytes(
+ signing_input = force_bytes(
'eyJhbGciOiJFUzUxMiIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb'
'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb'
'3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS'
@@ -335,7 +555,7 @@ class TestAlgorithmsRFC7520:
'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4'
)
- signature = base64url_decode(ensure_bytes(
+ signature = base64url_decode(force_bytes(
'AE_R_YZCChjn4791jSQCrdPZCNYqHXCTZH0-JZGYNlaAjP2kqaluUIIUnC9qvbu9P'
'lon7KRTzoNEuT4Va2cmL1eJAQy3mtPBu_u_sDDyYjnAMDxXPn7XrT0lw-kvAD890j'
'l8e2puQens_IEKBpHABlsbEPX6sFY8OcGDqoRuBomu9xQ2'
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index c56ec4b..855575c 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -8,12 +8,11 @@ from jwt.api_jws import PyJWS
from jwt.exceptions import (
DecodeError, InvalidAlgorithmError, InvalidTokenError
)
-from jwt.utils import base64url_decode
+from jwt.utils import base64url_decode, force_bytes, force_unicode
import pytest
from .compat import string_types, text_type
-from .utils import ensure_bytes, ensure_unicode
try:
from cryptography.hazmat.backends import default_backend
@@ -34,7 +33,7 @@ def jws():
@pytest.fixture
def payload():
""" Creates a sample jws claimset for use as a payload during tests """
- return ensure_bytes('hello world')
+ return force_bytes('hello world')
class TestJWS:
@@ -211,7 +210,7 @@ class TestJWS:
b'HAG0_zxxu0JyINOFT2iqF3URYl9HZ8kZWMeZAtXmn6Cw'
b'PXRJD2f7N-f7bJ5JeL9VT5beI2XD3FlK3GgRvI-eE-2Ik')
decoded_payload = jws.decode(example_jws, example_pubkey)
- json_payload = json.loads(ensure_unicode(decoded_payload))
+ json_payload = json.loads(force_unicode(decoded_payload))
assert json_payload == example_payload
@@ -236,7 +235,7 @@ class TestJWS:
b'uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr'
b'qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A')
decoded_payload = jws.decode(example_jws, example_pubkey)
- json_payload = json.loads(ensure_unicode(decoded_payload))
+ json_payload = json.loads(force_unicode(decoded_payload))
assert json_payload == example_payload
@@ -410,12 +409,12 @@ class TestJWS:
def test_encode_decode_with_rsa_sha256(self, jws, payload):
# PEM-formatted RSA key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
- priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()),
+ priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256')
with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
- pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()),
+ pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
backend=default_backend())
jws.decode(jws_message, pub_rsakey)
@@ -433,12 +432,12 @@ class TestJWS:
def test_encode_decode_with_rsa_sha384(self, jws, payload):
# PEM-formatted RSA key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
- priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()),
+ priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384')
with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
- pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()),
+ pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
backend=default_backend())
jws.decode(jws_message, pub_rsakey)
@@ -455,12 +454,12 @@ class TestJWS:
def test_encode_decode_with_rsa_sha512(self, jws, payload):
# PEM-formatted RSA key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
- priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()),
+ priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512')
with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
- pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()),
+ pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
backend=default_backend())
jws.decode(jws_message, pub_rsakey)
@@ -497,12 +496,12 @@ class TestJWS:
def test_encode_decode_with_ecdsa_sha256(self, jws, payload):
# PEM-formatted EC key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
- priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()),
+ priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_eckey, algorithm='ES256')
with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
- pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()),
+ pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()),
backend=default_backend())
jws.decode(jws_message, pub_eckey)
@@ -520,12 +519,12 @@ class TestJWS:
# PEM-formatted EC key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
- priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()),
+ priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_eckey, algorithm='ES384')
with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
- pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()),
+ pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()),
backend=default_backend())
jws.decode(jws_message, pub_eckey)
@@ -542,12 +541,12 @@ class TestJWS:
def test_encode_decode_with_ecdsa_sha512(self, jws, payload):
# PEM-formatted EC key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
- priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()),
+ priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode(payload, priv_eckey, algorithm='ES512')
with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
- pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()), backend=default_backend())
+ pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend())
jws.decode(jws_message, pub_eckey)
# string-formatted key
@@ -606,8 +605,8 @@ class TestJWS:
token = jws.encode(payload, 'secret', headers=data,
json_encoder=CustomJSONEncoder)
- header = ensure_bytes(ensure_unicode(token).split('.')[0])
- header = json.loads(ensure_unicode(base64url_decode(header)))
+ header = force_bytes(force_unicode(token).split('.')[0])
+ header = json.loads(force_unicode(base64url_decode(header)))
assert 'some_decimal' in header
assert header['some_decimal'] == 'it worked'
diff --git a/tests/test_compat.py b/tests/test_compat.py
index 6f6d6d0..10beb94 100644
--- a/tests/test_compat.py
+++ b/tests/test_compat.py
@@ -1,20 +1,19 @@
from jwt.compat import constant_time_compare
-
-from .utils import ensure_bytes
+from jwt.utils import force_bytes
class TestCompat:
def test_constant_time_compare_returns_true_if_same(self):
assert constant_time_compare(
- ensure_bytes('abc'), ensure_bytes('abc')
+ force_bytes('abc'), force_bytes('abc')
)
def test_constant_time_compare_returns_false_if_diff_lengths(self):
assert not constant_time_compare(
- ensure_bytes('abc'), ensure_bytes('abcd')
+ force_bytes('abc'), force_bytes('abcd')
)
def test_constant_time_compare_returns_false_if_totally_different(self):
assert not constant_time_compare(
- ensure_bytes('abcd'), ensure_bytes('efgh')
+ force_bytes('abcd'), force_bytes('efgh')
)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..7549faa
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,40 @@
+from jwt.utils import (
+ force_bytes, force_unicode, from_base64url_uint, to_base64url_uint
+)
+
+import pytest
+
+
+@pytest.mark.parametrize("inputval,expected", [
+ (0, b'AA'),
+ (1, b'AQ'),
+ (255, b'_w'),
+ (65537, b'AQAB'),
+ (123456789, b'B1vNFQ'),
+ pytest.mark.xfail((-1, ''), raises=ValueError)
+])
+def test_to_base64url_uint(inputval, expected):
+ actual = to_base64url_uint(inputval)
+ assert actual == expected
+
+
+@pytest.mark.parametrize("inputval,expected", [
+ (b'AA', 0),
+ (b'AQ', 1),
+ (b'_w', 255),
+ (b'AQAB', 65537),
+ (b'B1vNFQ', 123456789, ),
+])
+def test_from_base64url_uint(inputval, expected):
+ actual = from_base64url_uint(inputval)
+ assert actual == expected
+
+
+def test_force_unicode_raises_error_on_invalid_object():
+ with pytest.raises(TypeError):
+ force_unicode({})
+
+
+def test_force_bytes_raises_error_on_invalid_object():
+ with pytest.raises(TypeError):
+ force_bytes({})
diff --git a/tests/utils.py b/tests/utils.py
index 7722702..2e3f043 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,22 +4,6 @@ import struct
from calendar import timegm
from datetime import datetime
-from .compat import text_type
-
-
-def ensure_bytes(key):
- if isinstance(key, text_type):
- key = key.encode('utf-8')
-
- return key
-
-
-def ensure_unicode(key):
- if not isinstance(key, text_type):
- key = key.decode()
-
- return key
-
def utc_timestamp():
return timegm(datetime.utcnow().utctimetuple())