diff options
author | Mark Adams <mark@markadams.me> | 2016-05-05 23:58:32 -0500 |
---|---|---|
committer | Mark Adams <mark@markadams.me> | 2016-08-28 18:08:55 -0500 |
commit | 42b011410b28ea3e3920f40641a57c8893e2e04e (patch) | |
tree | de27ac86a51799d196e4dcfd04c5cbd069e7d22b /jwt/algorithms.py | |
parent | b35d522135044ba10ac41e7db5b95348cb4c4707 (diff) | |
download | pyjwt-42b011410b28ea3e3920f40641a57c8893e2e04e.tar.gz |
Add JWK support for HMAC and RSA keysadd-jwk-for-hmac-rsa
- JWKs for RSA and HMAC can be encoded / decoded using the .to_jwk() and
.from_jwk() methods on their respective jwt.algorithms instances
- Replaced tests.utils ensure_unicode and ensure_bytes with jwt.utils versions
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 158 |
1 files changed, 144 insertions, 14 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 51e8f16..3144f23 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,9 +1,15 @@ import hashlib import hmac +import json -from .compat import binary_type, constant_time_compare, is_string_type + +from .compat import constant_time_compare, string_types from .exceptions import InvalidKeyError -from .utils import der_to_raw_signature, raw_to_der_signature +from .utils import ( + base64url_decode, base64url_encode, der_to_raw_signature, + force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature, + to_base64url_uint +) try: from cryptography.hazmat.primitives import hashes @@ -11,7 +17,8 @@ try: load_pem_private_key, load_pem_public_key, load_ssh_public_key ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey + RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers, + rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp ) from cryptography.hazmat.primitives.asymmetric.ec import ( EllipticCurvePrivateKey, EllipticCurvePublicKey @@ -77,6 +84,20 @@ class Algorithm(object): """ raise NotImplementedError + @staticmethod + def to_jwk(key_obj): + """ + Serializes a given RSA key into a JWK + """ + raise NotImplementedError + + @staticmethod + def from_jwk(jwk): + """ + Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + """ + raise NotImplementedError + class NoneAlgorithm(Algorithm): """ @@ -112,11 +133,7 @@ class HMACAlgorithm(Algorithm): self.hash_alg = hash_alg def prepare_key(self, key): - if not is_string_type(key): - raise TypeError('Expecting a string- or bytes-formatted key.') - - if not isinstance(key, binary_type): - key = key.encode('utf-8') + key = force_bytes(key) invalid_strings = [ b'-----BEGIN PUBLIC KEY-----', @@ -131,6 +148,22 @@ class HMACAlgorithm(Algorithm): return key + @staticmethod + def to_jwk(key_obj): + return json.dumps({ + 'k': force_unicode(base64url_encode(force_bytes(key_obj))), + 'kty': 'oct' + }) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'oct': + raise InvalidKeyError('Not an HMAC key') + + return base64url_decode(obj['k']) + def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() @@ -156,9 +189,8 @@ if has_crypto: isinstance(key, RSAPublicKey): return key - if is_string_type(key): - if not isinstance(key, binary_type): - key = key.encode('utf-8') + if isinstance(key, string_types): + key = force_bytes(key) try: if key.startswith(b'ssh-rsa'): @@ -172,6 +204,105 @@ if has_crypto: return key + @staticmethod + def to_jwk(key_obj): + obj = None + + if getattr(key_obj, 'private_numbers', None): + # Private key + numbers = key_obj.private_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['sign'], + 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)), + 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)), + 'd': force_unicode(to_base64url_uint(numbers.d)), + 'p': force_unicode(to_base64url_uint(numbers.p)), + 'q': force_unicode(to_base64url_uint(numbers.q)), + 'dp': force_unicode(to_base64url_uint(numbers.dmp1)), + 'dq': force_unicode(to_base64url_uint(numbers.dmq1)), + 'qi': force_unicode(to_base64url_uint(numbers.iqmp)) + } + + elif getattr(key_obj, 'verifier', None): + # Public key + numbers = key_obj.public_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['verify'], + 'n': force_unicode(to_base64url_uint(numbers.n)), + 'e': force_unicode(to_base64url_uint(numbers.e)) + } + else: + raise InvalidKeyError('Not a public or private key') + + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk): + try: + obj = json.loads(jwk) + except ValueError: + raise InvalidKeyError('Key is not valid JSON') + + if obj.get('kty') != 'RSA': + raise InvalidKeyError('Not an RSA key') + + if 'd' in obj and 'e' in obj and 'n' in obj: + # Private key + if 'oth' in obj: + raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + + other_props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [prop in obj for prop in other_props] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj['d']), + p=from_base64url_uint(obj['p']), + q=from_base64url_uint(obj['q']), + dmp1=from_base64url_uint(obj['dp']), + dmq1=from_base64url_uint(obj['dq']), + iqmp=from_base64url_uint(obj['qi']), + public_numbers=public_numbers + ) + else: + d = from_base64url_uint(obj['d']) + p, q = rsa_recover_prime_factors( + public_numbers.n, d, public_numbers.e + ) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers + ) + + return numbers.private_key(default_backend()) + elif 'n' in obj and 'e' in obj: + # Public key + numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + return numbers.public_key(default_backend()) + else: + raise InvalidKeyError('Not a public or private key') + def sign(self, msg, key): signer = key.signer( padding.PKCS1v15(), @@ -213,9 +344,8 @@ if has_crypto: isinstance(key, EllipticCurvePublicKey): return key - if is_string_type(key): - if not isinstance(key, binary_type): - key = key.encode('utf-8') + if isinstance(key, string_types): + key = force_bytes(key) # Attempt to load key. We don't know if it's # a Signing Key or a Verifying Key, so we try |