summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
authorJosé Padilla <jpadilla@webapplicate.com>2019-10-21 22:38:34 -0400
committerGitHub <noreply@github.com>2019-10-21 22:38:34 -0400
commit11ac89474b1179925c76450fcc4b3d2042c45f19 (patch)
treedda6b15326cab750f5e52ee57e3125bb5d0a7eee /jwt/algorithms.py
parentae080f472c913ad94456fd9e10b05ec2d038b7cc (diff)
downloadpyjwt-11ac89474b1179925c76450fcc4b3d2042c45f19.tar.gz
DX Tweaks (#450)
* Setup pre-commit hooks * Run initial `tox -e lint` * Fix package name * Fix .travis.yml
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py251
1 files changed, 154 insertions, 97 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 1343688..293a470 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -2,26 +2,39 @@ import hashlib
import hmac
import json
-
from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import (
- base64url_decode, base64url_encode, der_to_raw_signature,
- force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
- to_base64url_uint
+ base64url_decode,
+ base64url_encode,
+ der_to_raw_signature,
+ force_bytes,
+ force_unicode,
+ from_base64url_uint,
+ raw_to_der_signature,
+ to_base64url_uint,
)
try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
- load_pem_private_key, load_pem_public_key, load_ssh_public_key
+ load_pem_private_key,
+ load_pem_public_key,
+ load_ssh_public_key,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
- rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
+ RSAPrivateKey,
+ RSAPublicKey,
+ RSAPrivateNumbers,
+ RSAPublicNumbers,
+ rsa_recover_prime_factors,
+ rsa_crt_dmp1,
+ rsa_crt_dmq1,
+ rsa_crt_iqmp,
)
from cryptography.hazmat.primitives.asymmetric.ec import (
- EllipticCurvePrivateKey, EllipticCurvePublicKey
+ EllipticCurvePrivateKey,
+ EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.backends import default_backend
@@ -31,8 +44,20 @@ try:
except ImportError:
has_crypto = False
-requires_cryptography = set(['RS256', 'RS384', 'RS512', 'ES256', 'ES384',
- 'ES521', 'ES512', 'PS256', 'PS384', 'PS512'])
+requires_cryptography = set(
+ [
+ "RS256",
+ "RS384",
+ "RS512",
+ "ES256",
+ "ES384",
+ "ES521",
+ "ES512",
+ "PS256",
+ "PS384",
+ "PS512",
+ ]
+)
def get_default_algorithms():
@@ -40,25 +65,29 @@ def get_default_algorithms():
Returns the algorithms that are implemented by the library.
"""
default_algorithms = {
- 'none': NoneAlgorithm(),
- 'HS256': HMACAlgorithm(HMACAlgorithm.SHA256),
- 'HS384': HMACAlgorithm(HMACAlgorithm.SHA384),
- 'HS512': HMACAlgorithm(HMACAlgorithm.SHA512)
+ "none": NoneAlgorithm(),
+ "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
+ "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
+ "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
}
if has_crypto:
- default_algorithms.update({
- 'RS256': RSAAlgorithm(RSAAlgorithm.SHA256),
- 'RS384': RSAAlgorithm(RSAAlgorithm.SHA384),
- 'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
- 'ES256': ECAlgorithm(ECAlgorithm.SHA256),
- 'ES384': ECAlgorithm(ECAlgorithm.SHA384),
- 'ES521': ECAlgorithm(ECAlgorithm.SHA512),
- 'ES512': ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix
- 'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
- 'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
- 'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
- })
+ default_algorithms.update(
+ {
+ "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
+ "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
+ "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
+ "ES256": ECAlgorithm(ECAlgorithm.SHA256),
+ "ES384": ECAlgorithm(ECAlgorithm.SHA384),
+ "ES521": ECAlgorithm(ECAlgorithm.SHA512),
+ "ES512": ECAlgorithm(
+ ECAlgorithm.SHA512
+ ), # Backward compat for #219 fix
+ "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
+ "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
+ "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
+ }
+ )
return default_algorithms
@@ -67,6 +96,7 @@ class Algorithm(object):
"""
The interface for an algorithm used to sign and verify tokens.
"""
+
def prepare_key(self, key):
"""
Performs necessary validation and conversions on the key and returns
@@ -108,8 +138,9 @@ class NoneAlgorithm(Algorithm):
Placeholder for use when no signing or verification
operations are required.
"""
+
def prepare_key(self, key):
- if key == '':
+ if key == "":
key = None
if key is not None:
@@ -118,7 +149,7 @@ class NoneAlgorithm(Algorithm):
return key
def sign(self, msg, key):
- return b''
+ return b""
def verify(self, msg, key, sig):
return False
@@ -129,6 +160,7 @@ class HMACAlgorithm(Algorithm):
Performs signing and verification operations using HMAC
and the specified hash function.
"""
+
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
@@ -140,34 +172,37 @@ class HMACAlgorithm(Algorithm):
key = force_bytes(key)
invalid_strings = [
- b'-----BEGIN PUBLIC KEY-----',
- b'-----BEGIN CERTIFICATE-----',
- b'-----BEGIN RSA PUBLIC KEY-----',
- b'ssh-rsa'
+ b"-----BEGIN PUBLIC KEY-----",
+ b"-----BEGIN CERTIFICATE-----",
+ b"-----BEGIN RSA PUBLIC KEY-----",
+ b"ssh-rsa",
]
if any([string_value in key for string_value in invalid_strings]):
raise InvalidKeyError(
- 'The specified key is an asymmetric key or x509 certificate and'
- ' should not be used as an HMAC secret.')
+ "The specified key is an asymmetric key or x509 certificate and"
+ " should not be used as an HMAC secret."
+ )
return key
@staticmethod
def to_jwk(key_obj):
- return json.dumps({
- 'k': force_unicode(base64url_encode(force_bytes(key_obj))),
- 'kty': 'oct'
- })
+ return json.dumps(
+ {
+ "k": force_unicode(base64url_encode(force_bytes(key_obj))),
+ "kty": "oct",
+ }
+ )
@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)
- if obj.get('kty') != 'oct':
- raise InvalidKeyError('Not an HMAC key')
+ if obj.get("kty") != "oct":
+ raise InvalidKeyError("Not an HMAC key")
- return base64url_decode(obj['k'])
+ return base64url_decode(obj["k"])
def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()
@@ -176,13 +211,14 @@ class HMACAlgorithm(Algorithm):
return constant_time_compare(sig, self.sign(msg, key))
-if has_crypto:
+if has_crypto: # noqa: C901
class RSAAlgorithm(Algorithm):
"""
Performs signing and verification operations using
RSASSA-PKCS-v1_5 and the specified hash function.
"""
+
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
@@ -191,22 +227,25 @@ if has_crypto:
self.hash_alg = hash_alg
def prepare_key(self, key):
- if isinstance(key, RSAPrivateKey) or \
- isinstance(key, RSAPublicKey):
+ if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey):
return key
if isinstance(key, string_types):
key = force_bytes(key)
try:
- if key.startswith(b'ssh-rsa'):
- key = load_ssh_public_key(key, backend=default_backend())
+ if key.startswith(b"ssh-rsa"):
+ key = load_ssh_public_key(
+ key, backend=default_backend()
+ )
else:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ key = load_pem_private_key(
+ key, password=None, backend=default_backend()
+ )
except ValueError:
key = load_pem_public_key(key, backend=default_backend())
else:
- raise TypeError('Expecting a PEM-formatted key.')
+ raise TypeError("Expecting a PEM-formatted key.")
return key
@@ -214,35 +253,39 @@ if has_crypto:
def to_jwk(key_obj):
obj = None
- if getattr(key_obj, 'private_numbers', None):
+ if getattr(key_obj, "private_numbers", None):
# Private key
numbers = key_obj.private_numbers()
obj = {
- 'kty': 'RSA',
- 'key_ops': ['sign'],
- 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
- 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
- 'd': force_unicode(to_base64url_uint(numbers.d)),
- 'p': force_unicode(to_base64url_uint(numbers.p)),
- 'q': force_unicode(to_base64url_uint(numbers.q)),
- 'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
- 'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
- 'qi': force_unicode(to_base64url_uint(numbers.iqmp))
+ "kty": "RSA",
+ "key_ops": ["sign"],
+ "n": force_unicode(
+ to_base64url_uint(numbers.public_numbers.n)
+ ),
+ "e": force_unicode(
+ to_base64url_uint(numbers.public_numbers.e)
+ ),
+ "d": force_unicode(to_base64url_uint(numbers.d)),
+ "p": force_unicode(to_base64url_uint(numbers.p)),
+ "q": force_unicode(to_base64url_uint(numbers.q)),
+ "dp": force_unicode(to_base64url_uint(numbers.dmp1)),
+ "dq": force_unicode(to_base64url_uint(numbers.dmq1)),
+ "qi": force_unicode(to_base64url_uint(numbers.iqmp)),
}
- elif getattr(key_obj, 'verify', None):
+ elif getattr(key_obj, "verify", None):
# Public key
numbers = key_obj.public_numbers()
obj = {
- 'kty': 'RSA',
- 'key_ops': ['verify'],
- 'n': force_unicode(to_base64url_uint(numbers.n)),
- 'e': force_unicode(to_base64url_uint(numbers.e))
+ "kty": "RSA",
+ "key_ops": ["verify"],
+ "n": force_unicode(to_base64url_uint(numbers.n)),
+ "e": force_unicode(to_base64url_uint(numbers.e)),
}
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError("Not a public or private key")
return json.dumps(obj)
@@ -251,39 +294,44 @@ if has_crypto:
try:
obj = json.loads(jwk)
except ValueError:
- raise InvalidKeyError('Key is not valid JSON')
+ raise InvalidKeyError("Key is not valid JSON")
- if obj.get('kty') != 'RSA':
- raise InvalidKeyError('Not an RSA key')
+ if obj.get("kty") != "RSA":
+ raise InvalidKeyError("Not an RSA key")
- if 'd' in obj and 'e' in obj and 'n' in obj:
+ if "d" in obj and "e" in obj and "n" in obj:
# Private key
- if 'oth' in obj:
- raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
+ if "oth" in obj:
+ raise InvalidKeyError(
+ "Unsupported RSA private key: > 2 primes not supported"
+ )
- other_props = ['p', 'q', 'dp', 'dq', 'qi']
+ other_props = ["p", "q", "dp", "dq", "qi"]
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)
if any_props_found and not all(props_found):
- raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
+ raise InvalidKeyError(
+ "RSA key must include all parameters if any are present besides d"
+ )
public_numbers = RSAPublicNumbers(
- from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ from_base64url_uint(obj["e"]),
+ from_base64url_uint(obj["n"]),
)
if any_props_found:
numbers = RSAPrivateNumbers(
- d=from_base64url_uint(obj['d']),
- p=from_base64url_uint(obj['p']),
- q=from_base64url_uint(obj['q']),
- dmp1=from_base64url_uint(obj['dp']),
- dmq1=from_base64url_uint(obj['dq']),
- iqmp=from_base64url_uint(obj['qi']),
- public_numbers=public_numbers
+ d=from_base64url_uint(obj["d"]),
+ p=from_base64url_uint(obj["p"]),
+ q=from_base64url_uint(obj["q"]),
+ dmp1=from_base64url_uint(obj["dp"]),
+ dmq1=from_base64url_uint(obj["dq"]),
+ iqmp=from_base64url_uint(obj["qi"]),
+ public_numbers=public_numbers,
)
else:
- d = from_base64url_uint(obj['d'])
+ d = from_base64url_uint(obj["d"])
p, q = rsa_recover_prime_factors(
public_numbers.n, d, public_numbers.e
)
@@ -295,19 +343,20 @@ if has_crypto:
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
- public_numbers=public_numbers
+ public_numbers=public_numbers,
)
return numbers.private_key(default_backend())
- elif 'n' in obj and 'e' in obj:
+ elif "n" in obj and "e" in obj:
# Public key
numbers = RSAPublicNumbers(
- from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
+ from_base64url_uint(obj["e"]),
+ from_base64url_uint(obj["n"]),
)
return numbers.public_key(default_backend())
else:
- raise InvalidKeyError('Not a public or private key')
+ raise InvalidKeyError("Not a public or private key")
def sign(self, msg, key):
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
@@ -324,6 +373,7 @@ if has_crypto:
Performs signing and verification operations using
ECDSA and the specified hash function
"""
+
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
@@ -332,8 +382,9 @@ if has_crypto:
self.hash_alg = hash_alg
def prepare_key(self, key):
- if isinstance(key, EllipticCurvePrivateKey) or \
- isinstance(key, EllipticCurvePublicKey):
+ if isinstance(key, EllipticCurvePrivateKey) or isinstance(
+ key, EllipticCurvePublicKey
+ ):
return key
if isinstance(key, string_types):
@@ -343,15 +394,21 @@ if has_crypto:
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
- if key.startswith(b'ecdsa-sha2-'):
- key = load_ssh_public_key(key, backend=default_backend())
+ if key.startswith(b"ecdsa-sha2-"):
+ key = load_ssh_public_key(
+ key, backend=default_backend()
+ )
else:
- key = load_pem_public_key(key, backend=default_backend())
+ key = load_pem_public_key(
+ key, backend=default_backend()
+ )
except ValueError:
- key = load_pem_private_key(key, password=None, backend=default_backend())
+ key = load_pem_private_key(
+ key, password=None, backend=default_backend()
+ )
else:
- raise TypeError('Expecting a PEM-formatted key.')
+ raise TypeError("Expecting a PEM-formatted key.")
return key
@@ -382,9 +439,9 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size
+ salt_length=self.hash_alg.digest_size,
),
- self.hash_alg()
+ self.hash_alg(),
)
def verify(self, msg, key, sig):
@@ -394,9 +451,9 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size
+ salt_length=self.hash_alg.digest_size,
),
- self.hash_alg()
+ self.hash_alg(),
)
return True
except InvalidSignature: