diff options
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 189 |
1 files changed, 107 insertions, 82 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index bc928fe..4bcf81b 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import hashlib import hmac import json from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Type, Union +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -19,7 +21,6 @@ from .utils import ( ) try: - import cryptography.exceptions from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes @@ -68,6 +69,23 @@ try: except ModuleNotFoundError: has_crypto = False + +if TYPE_CHECKING: + # Type aliases for convenience in algorithms method signatures + AllowedRSAKeys = RSAPrivateKey | RSAPublicKey + AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey + AllowedOKPKeys = ( + Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + ) + AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys + AllowedPrivateKeys = ( + RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + ) + AllowedPublicKeys = ( + RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + ) + + requires_cryptography = { "RS256", "RS384", @@ -84,7 +102,7 @@ requires_cryptography = { } -def get_default_algorithms() -> Dict[str, "Algorithm"]: +def get_default_algorithms() -> dict[str, Algorithm]: """ Returns the algorithms that are implemented by the library. """ @@ -145,10 +163,6 @@ class Algorithm(ABC): else: return bytes(hash_alg(bytestr).digest()) - # TODO: all key-related `Any`s in this class should optimally be made - # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 - # that may still be poorly supported. - @abstractmethod def prepare_key(self, key: Any) -> Any: """ @@ -172,16 +186,16 @@ class Algorithm(ABC): @staticmethod @abstractmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj) -> str: """ - Serializes a given RSA key into a JWK + Serializes a given key into a JWK """ @staticmethod @abstractmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: str | JWKDict) -> Any: """ - Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + Deserializes a given key from JWK back into a key object """ @@ -191,7 +205,7 @@ class NoneAlgorithm(Algorithm): operations are required. """ - def prepare_key(self, key): + def prepare_key(self, key: str | None) -> None: if key == "": key = None @@ -200,18 +214,18 @@ class NoneAlgorithm(Algorithm): return key - def sign(self, msg, key): + def sign(self, msg: bytes, key: None) -> bytes: return b"" - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: None, sig: bytes) -> bool: return False @staticmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj: Any) -> NoReturn: raise NotImplementedError() @staticmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: str | JWKDict) -> NoReturn: raise NotImplementedError() @@ -228,19 +242,19 @@ class HMACAlgorithm(Algorithm): def __init__(self, hash_alg: HashlibHash) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): - key = force_bytes(key) + def prepare_key(self, key: str | bytes) -> bytes: + key_bytes = force_bytes(key) - if is_pem_format(key) or is_ssh_key(key): + if is_pem_format(key_bytes) or is_ssh_key(key_bytes): raise InvalidKeyError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." ) - return key + return key_bytes @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: str | bytes) -> str: return json.dumps( { "k": base64url_encode(force_bytes(key_obj)).decode(), @@ -249,10 +263,10 @@ class HMACAlgorithm(Algorithm): ) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> bytes: try: if isinstance(jwk, str): - obj = json.loads(jwk) + obj: JWKDict = json.loads(jwk) elif isinstance(jwk, dict): obj = jwk else: @@ -268,7 +282,7 @@ class HMACAlgorithm(Algorithm): def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool: return hmac.compare_digest(sig, self.sign(msg, key)) @@ -280,14 +294,14 @@ if has_crypto: RSASSA-PKCS-v1_5 and the specified hash function. """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key @@ -298,15 +312,17 @@ if has_crypto: try: if key_bytes.startswith(b"ssh-rsa"): - return load_ssh_public_key(key_bytes) + return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) else: - return load_pem_private_key(key_bytes, password=None) + return cast( + RSAPrivateKey, load_pem_private_key(key_bytes, password=None) + ) except ValueError: - return load_pem_public_key(key_bytes) + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj): - obj = None + def to_jwk(key_obj: AllowedRSAKeys) -> str: + obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): # Private key @@ -341,7 +357,7 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -412,10 +428,10 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) return True @@ -428,14 +444,14 @@ if has_crypto: ECDSA and the specified hash function """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key @@ -449,41 +465,46 @@ if has_crypto: # the Verifying Key first. try: if key_bytes.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key_bytes) + crypto_key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key_bytes) + crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] except ValueError: - key = load_pem_private_key(key_bytes, password=None) + crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography - if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + if not isinstance( + crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) + ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" ) - return key + return crypto_key - def sign(self, msg, key): + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: return False try: - if isinstance(key, EllipticCurvePrivateKey): - key = key.public_key() - key.verify(der_sig, msg, ECDSA(self.hash_alg())) + public_key = ( + key.public_key() + if isinstance(key, EllipticCurvePrivateKey) + else key + ) + public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: AllowedECKeys) -> str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -502,7 +523,7 @@ if has_crypto: else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") - obj = { + obj: dict[str, Any] = { "kty": "EC", "crv": crv, "x": to_base64url_uint(public_numbers.x).decode(), @@ -517,9 +538,7 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk( - jwk: Any, - ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]: + def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -591,7 +610,7 @@ if has_crypto: Performs a signature using RSASSA-PSS with MGF1 """ - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign( msg, padding.PSS( @@ -601,7 +620,7 @@ if has_crypto: self.hash_alg(), ) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify( sig, @@ -623,21 +642,20 @@ if has_crypto: This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key): + def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: if isinstance(key, (bytes, str)): - if isinstance(key, str): - key = key.encode("utf-8") - str_key = key.decode("utf-8") + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + key_bytes = key.encode("utf-8") if isinstance(key, str) else key - if "-----BEGIN PUBLIC" in str_key: - key = load_pem_public_key(key) - elif "-----BEGIN PRIVATE" in str_key: - key = load_pem_private_key(key, password=None) - elif str_key[0:4] == "ssh-": - key = load_ssh_public_key(key) + if "-----BEGIN PUBLIC" in key_str: + key = load_pem_public_key(key_bytes) # type: ignore[assignment] + elif "-----BEGIN PRIVATE" in key_str: + key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + elif key_str[0:4] == "ssh-": + key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography if not isinstance( @@ -650,7 +668,9 @@ if has_crypto: return key - def sign(self, msg, key): + def sign( + self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey + ) -> bytes: """ Sign a message ``msg`` using the EdDSA private key ``key`` :param str|bytes msg: Message to sign @@ -658,10 +678,12 @@ if has_crypto: or :class:`.Ed448PrivateKey` isinstance :return bytes signature: The signature, as bytes """ - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - return key.sign(msg) + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + return key.sign(msg_bytes) - def verify(self, msg, key, sig): + def verify( + self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes + ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -672,18 +694,21 @@ if has_crypto: :return bool verified: True if signature is valid, False if not. """ try: - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig - if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): - key = key.public_key() - key.verify(sig, msg) + public_key = ( + key.public_key() + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) + else key + ) + public_key.verify(sig_bytes, msg_bytes) return True # If no exception was raised, the signature is valid. - except cryptography.exceptions.InvalidSignature: + except InvalidSignature: return False @staticmethod - def to_jwk(key): + def to_jwk(key: AllowedOKPKeys) -> str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -723,7 +748,7 @@ if has_crypto: raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) |