diff options
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 99 |
1 files changed, 57 insertions, 42 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 4fae441..c809f15 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,8 +1,10 @@ import hashlib import hmac import json +from typing import Any, Dict, Union from .exceptions import InvalidKeyError +from .types import JWKDict from .utils import ( base64url_decode, base64url_encode, @@ -20,10 +22,18 @@ try: from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.ec import ( + ECDSA, + SECP256K1, + SECP256R1, + SECP384R1, + SECP521R1, + EllipticCurve, EllipticCurvePrivateKey, + EllipticCurvePrivateNumbers, EllipticCurvePublicKey, + EllipticCurvePublicNumbers, ) from cryptography.hazmat.primitives.asymmetric.ed448 import ( Ed448PrivateKey, @@ -73,7 +83,7 @@ requires_cryptography = { } -def get_default_algorithms(): +def get_default_algorithms() -> Dict[str, "Algorithm"]: """ Returns the algorithms that are implemented by the library. """ @@ -130,25 +140,29 @@ class Algorithm: ): digest = hashes.Hash(hash_alg(), backend=default_backend()) digest.update(bytestr) - return digest.finalize() + return bytes(digest.finalize()) else: - return hash_alg(bytestr).digest() + return bytes(hash_alg(bytestr).digest()) - def prepare_key(self, key): + # TODO: all key-related `Any`s in this class should optimally be made + # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 + # that may still be poorly supported. + + def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ raise NotImplementedError - def sign(self, msg, key): + def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ raise NotImplementedError - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. @@ -156,14 +170,14 @@ class Algorithm: raise NotImplementedError @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj) -> JWKDict: """ Serializes a given RSA key into a JWK """ raise NotImplementedError @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: JWKDict): """ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object """ @@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm): SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -242,7 +256,7 @@ class HMACAlgorithm(Algorithm): return base64url_decode(obj["k"]) - def sign(self, msg, key): + def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() def verify(self, msg, key, sig): @@ -261,7 +275,7 @@ if has_crypto: SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -271,16 +285,15 @@ if has_crypto: if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) try: - if key.startswith(b"ssh-rsa"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ssh-rsa"): + return load_ssh_public_key(key_bytes) else: - key = load_pem_private_key(key, password=None) + return load_pem_private_key(key_bytes, password=None) except ValueError: - key = load_pem_public_key(key) - return key + return load_pem_public_key(key_bytes) @staticmethod def to_jwk(key_obj): @@ -383,12 +396,10 @@ if has_crypto: return numbers.private_key() elif "n" in obj and "e" in obj: # Public key - numbers = RSAPublicNumbers( + return RSAPublicNumbers( from_base64url_uint(obj["e"]), from_base64url_uint(obj["n"]), - ) - - return numbers.public_key() + ).public_key() else: raise InvalidKeyError("Not a public or private key") @@ -412,7 +423,7 @@ if has_crypto: SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -422,18 +433,18 @@ if has_crypto: if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) # Attempt to load key. We don't know if it's # a Signing Key or a Verifying Key, so we try # the Verifying Key first. try: - if key.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ecdsa-sha2-"): + key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key) + key = load_pem_public_key(key_bytes) except ValueError: - key = load_pem_private_key(key, password=None) + key = load_pem_private_key(key_bytes, password=None) # Explicit check the key to prevent confusing errors from cryptography if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): @@ -444,7 +455,7 @@ if has_crypto: return key def sign(self, msg, key): - der_sig = key.sign(msg, ec.ECDSA(self.hash_alg())) + der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) @@ -457,7 +468,7 @@ if has_crypto: try: if isinstance(key, EllipticCurvePrivateKey): key = key.public_key() - key.verify(der_sig, msg, ec.ECDSA(self.hash_alg())) + key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False @@ -472,13 +483,13 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - if isinstance(key_obj.curve, ec.SECP256R1): + if isinstance(key_obj.curve, SECP256R1): crv = "P-256" - elif isinstance(key_obj.curve, ec.SECP384R1): + elif isinstance(key_obj.curve, SECP384R1): crv = "P-384" - elif isinstance(key_obj.curve, ec.SECP521R1): + elif isinstance(key_obj.curve, SECP521R1): crv = "P-521" - elif isinstance(key_obj.curve, ec.SECP256K1): + elif isinstance(key_obj.curve, SECP256K1): crv = "secp256k1" else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") @@ -498,7 +509,9 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk( + jwk: Any, + ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -519,24 +532,26 @@ if has_crypto: y = base64url_decode(obj.get("y")) curve = obj.get("crv") + curve_obj: EllipticCurve + if curve == "P-256": if len(x) == len(y) == 32: - curve_obj = ec.SECP256R1() + curve_obj = SECP256R1() else: raise InvalidKeyError("Coords should be 32 bytes for curve P-256") elif curve == "P-384": if len(x) == len(y) == 48: - curve_obj = ec.SECP384R1() + curve_obj = SECP384R1() else: raise InvalidKeyError("Coords should be 48 bytes for curve P-384") elif curve == "P-521": if len(x) == len(y) == 66: - curve_obj = ec.SECP521R1() + curve_obj = SECP521R1() else: raise InvalidKeyError("Coords should be 66 bytes for curve P-521") elif curve == "secp256k1": if len(x) == len(y) == 32: - curve_obj = ec.SECP256K1() + curve_obj = SECP256K1() else: raise InvalidKeyError( "Coords should be 32 bytes for curve secp256k1" @@ -544,7 +559,7 @@ if has_crypto: else: raise InvalidKeyError(f"Invalid curve: {curve}") - public_numbers = ec.EllipticCurvePublicNumbers( + public_numbers = EllipticCurvePublicNumbers( x=int.from_bytes(x, byteorder="big"), y=int.from_bytes(y, byteorder="big"), curve=curve_obj, @@ -559,7 +574,7 @@ if has_crypto: "D should be {} bytes for curve {}", len(x), curve ) - return ec.EllipticCurvePrivateNumbers( + return EllipticCurvePrivateNumbers( int.from_bytes(d, byteorder="big"), public_numbers ).private_key() @@ -600,7 +615,7 @@ if has_crypto: This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: pass def prepare_key(self, key): |