summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
authorMark Adams <mark@markadams.me>2016-05-05 23:58:32 -0500
committerMark Adams <mark@markadams.me>2016-08-28 18:08:55 -0500
commit42b011410b28ea3e3920f40641a57c8893e2e04e (patch)
treede27ac86a51799d196e4dcfd04c5cbd069e7d22b /jwt/algorithms.py
parentb35d522135044ba10ac41e7db5b95348cb4c4707 (diff)
downloadpyjwt-42b011410b28ea3e3920f40641a57c8893e2e04e.tar.gz
Add JWK support for HMAC and RSA keysadd-jwk-for-hmac-rsa
- JWKs for RSA and HMAC can be encoded / decoded using the .to_jwk() and .from_jwk() methods on their respective jwt.algorithms instances - Replaced tests.utils ensure_unicode and ensure_bytes with jwt.utils versions
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py158
1 files changed, 144 insertions, 14 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 51e8f16..3144f23 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,9 +1,15 @@
import hashlib
import hmac
+import json
-from .compat import binary_type, constant_time_compare, is_string_type
+
+from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
-from .utils import der_to_raw_signature, raw_to_der_signature
+from .utils import (
+ base64url_decode, base64url_encode, der_to_raw_signature,
+ force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
+ to_base64url_uint
+)
try:
from cryptography.hazmat.primitives import hashes
@@ -11,7 +17,8 @@ try:
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey, RSAPublicKey
+ RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
+ rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
@@ -77,6 +84,20 @@ class Algorithm(object):
"""
raise NotImplementedError
+ @staticmethod
+ def to_jwk(key_obj):
+ """
+ Serializes a given RSA key into a JWK
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def from_jwk(jwk):
+ """
+ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
+ """
+ raise NotImplementedError
+
class NoneAlgorithm(Algorithm):
"""
@@ -112,11 +133,7 @@ class HMACAlgorithm(Algorithm):
self.hash_alg = hash_alg
def prepare_key(self, key):
- if not is_string_type(key):
- raise TypeError('Expecting a string- or bytes-formatted key.')
-
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ key = force_bytes(key)
invalid_strings = [
b'-----BEGIN PUBLIC KEY-----',
@@ -131,6 +148,22 @@ class HMACAlgorithm(Algorithm):
return key
+ @staticmethod
+ def to_jwk(key_obj):
+ return json.dumps({
+ 'k': force_unicode(base64url_encode(force_bytes(key_obj))),
+ 'kty': 'oct'
+ })
+
+ @staticmethod
+ def from_jwk(jwk):
+ obj = json.loads(jwk)
+
+ if obj.get('kty') != 'oct':
+ raise InvalidKeyError('Not an HMAC key')
+
+ return base64url_decode(obj['k'])
+
def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()
@@ -156,9 +189,8 @@ if has_crypto:
isinstance(key, RSAPublicKey):
return key
- if is_string_type(key):
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ if isinstance(key, string_types):
+ key = force_bytes(key)
try:
if key.startswith(b'ssh-rsa'):
@@ -172,6 +204,105 @@ if has_crypto:
return key
+ @staticmethod
+ def to_jwk(key_obj):
+ obj = None
+
+ if getattr(key_obj, 'private_numbers', None):
+ # Private key
+ numbers = key_obj.private_numbers()
+
+ obj = {
+ 'kty': 'RSA',
+ 'key_ops': ['sign'],
+ 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
+ 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
+ 'd': force_unicode(to_base64url_uint(numbers.d)),
+ 'p': force_unicode(to_base64url_uint(numbers.p)),
+ 'q': force_unicode(to_base64url_uint(numbers.q)),
+ 'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
+ 'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
+ 'qi': force_unicode(to_base64url_uint(numbers.iqmp))
+ }
+
+ elif getattr(key_obj, 'verifier', None):
+ # Public key
+ numbers = key_obj.public_numbers()
+
+ obj = {
+ 'kty': 'RSA',
+ 'key_ops': ['verify'],
+ 'n': force_unicode(to_base64url_uint(numbers.n)),
+ 'e': force_unicode(to_base64url_uint(numbers.e))
+ }
+ else:
+ raise InvalidKeyError('Not a public or private key')
+
+ return json.dumps(obj)
+
+ @staticmethod
+ def from_jwk(jwk):
+ try:
+ obj = json.loads(jwk)
+ except ValueError:
+ raise InvalidKeyError('Key is not valid JSON')
+
+ if obj.get('kty') != 'RSA':
+ raise InvalidKeyError('Not an RSA key')
+
+ if 'd' in obj and 'e' in obj and 'n' in obj:
+ # Private key
+ if 'oth' in obj:
+ raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
+
+ other_props = ['p', 'q', 'dp', 'dq', 'qi']
+ props_found = [prop in obj for prop in other_props]
+ any_props_found = any(props_found)
+
+ if any_props_found and not all(props_found):
+ raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
+
+ public_numbers = RSAPublicNumbers(
+ from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ )
+
+ if any_props_found:
+ numbers = RSAPrivateNumbers(
+ d=from_base64url_uint(obj['d']),
+ p=from_base64url_uint(obj['p']),
+ q=from_base64url_uint(obj['q']),
+ dmp1=from_base64url_uint(obj['dp']),
+ dmq1=from_base64url_uint(obj['dq']),
+ iqmp=from_base64url_uint(obj['qi']),
+ public_numbers=public_numbers
+ )
+ else:
+ d = from_base64url_uint(obj['d'])
+ p, q = rsa_recover_prime_factors(
+ public_numbers.n, d, public_numbers.e
+ )
+
+ numbers = RSAPrivateNumbers(
+ d=d,
+ p=p,
+ q=q,
+ dmp1=rsa_crt_dmp1(d, p),
+ dmq1=rsa_crt_dmq1(d, q),
+ iqmp=rsa_crt_iqmp(p, q),
+ public_numbers=public_numbers
+ )
+
+ return numbers.private_key(default_backend())
+ elif 'n' in obj and 'e' in obj:
+ # Public key
+ numbers = RSAPublicNumbers(
+ from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ )
+
+ return numbers.public_key(default_backend())
+ else:
+ raise InvalidKeyError('Not a public or private key')
+
def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
@@ -213,9 +344,8 @@ if has_crypto:
isinstance(key, EllipticCurvePublicKey):
return key
- if is_string_type(key):
- if not isinstance(key, binary_type):
- key = key.encode('utf-8')
+ if isinstance(key, string_types):
+ key = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try