summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py189
1 files changed, 107 insertions, 82 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index bc928fe..4bcf81b 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import hashlib
import hmac
import json
from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Dict, Type, Union
+from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
@@ -19,7 +21,6 @@ from .utils import (
)
try:
- import cryptography.exceptions
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
@@ -68,6 +69,23 @@ try:
except ModuleNotFoundError:
has_crypto = False
+
+if TYPE_CHECKING:
+ # Type aliases for convenience in algorithms method signatures
+ AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
+ AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
+ AllowedOKPKeys = (
+ Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
+ )
+ AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
+ AllowedPrivateKeys = (
+ RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
+ )
+ AllowedPublicKeys = (
+ RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
+ )
+
+
requires_cryptography = {
"RS256",
"RS384",
@@ -84,7 +102,7 @@ requires_cryptography = {
}
-def get_default_algorithms() -> Dict[str, "Algorithm"]:
+def get_default_algorithms() -> dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
@@ -145,10 +163,6 @@ class Algorithm(ABC):
else:
return bytes(hash_alg(bytestr).digest())
- # TODO: all key-related `Any`s in this class should optimally be made
- # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
- # that may still be poorly supported.
-
@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
@@ -172,16 +186,16 @@ class Algorithm(ABC):
@staticmethod
@abstractmethod
- def to_jwk(key_obj) -> JWKDict:
+ def to_jwk(key_obj) -> str:
"""
- Serializes a given RSA key into a JWK
+ Serializes a given key into a JWK
"""
@staticmethod
@abstractmethod
- def from_jwk(jwk: JWKDict):
+ def from_jwk(jwk: str | JWKDict) -> Any:
"""
- Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
+ Deserializes a given key from JWK back into a key object
"""
@@ -191,7 +205,7 @@ class NoneAlgorithm(Algorithm):
operations are required.
"""
- def prepare_key(self, key):
+ def prepare_key(self, key: str | None) -> None:
if key == "":
key = None
@@ -200,18 +214,18 @@ class NoneAlgorithm(Algorithm):
return key
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: None) -> bytes:
return b""
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False
@staticmethod
- def to_jwk(key_obj) -> JWKDict:
+ def to_jwk(key_obj: Any) -> NoReturn:
raise NotImplementedError()
@staticmethod
- def from_jwk(jwk: JWKDict):
+ def from_jwk(jwk: str | JWKDict) -> NoReturn:
raise NotImplementedError()
@@ -228,19 +242,19 @@ class HMACAlgorithm(Algorithm):
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
- key = force_bytes(key)
+ def prepare_key(self, key: str | bytes) -> bytes:
+ key_bytes = force_bytes(key)
- if is_pem_format(key) or is_ssh_key(key):
+ if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
- return key
+ return key_bytes
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj: str | bytes) -> str:
return json.dumps(
{
"k": base64url_encode(force_bytes(key_obj)).decode(),
@@ -249,10 +263,10 @@ class HMACAlgorithm(Algorithm):
)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> bytes:
try:
if isinstance(jwk, str):
- obj = json.loads(jwk)
+ obj: JWKDict = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
@@ -268,7 +282,7 @@ class HMACAlgorithm(Algorithm):
def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
return hmac.compare_digest(sig, self.sign(msg, key))
@@ -280,14 +294,14 @@ if has_crypto:
RSASSA-PKCS-v1_5 and the specified hash function.
"""
- SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256
- SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384
- SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
return key
@@ -298,15 +312,17 @@ if has_crypto:
try:
if key_bytes.startswith(b"ssh-rsa"):
- return load_ssh_public_key(key_bytes)
+ return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else:
- return load_pem_private_key(key_bytes, password=None)
+ return cast(
+ RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
+ )
except ValueError:
- return load_pem_public_key(key_bytes)
+ return cast(RSAPublicKey, load_pem_public_key(key_bytes))
@staticmethod
- def to_jwk(key_obj):
- obj = None
+ def to_jwk(key_obj: AllowedRSAKeys) -> str:
+ obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"):
# Private key
@@ -341,7 +357,7 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -412,10 +428,10 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True
@@ -428,14 +444,14 @@ if has_crypto:
ECDSA and the specified hash function
"""
- SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256
- SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384
- SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key
@@ -449,41 +465,46 @@ if has_crypto:
# the Verifying Key first.
try:
if key_bytes.startswith(b"ecdsa-sha2-"):
- key = load_ssh_public_key(key_bytes)
+ crypto_key = load_ssh_public_key(key_bytes)
else:
- key = load_pem_public_key(key_bytes)
+ crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError:
- key = load_pem_private_key(key_bytes, password=None)
+ crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
- if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
+ if not isinstance(
+ crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
+ ):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
)
- return key
+ return crypto_key
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
try:
der_sig = raw_to_der_signature(sig, key.curve)
except ValueError:
return False
try:
- if isinstance(key, EllipticCurvePrivateKey):
- key = key.public_key()
- key.verify(der_sig, msg, ECDSA(self.hash_alg()))
+ public_key = (
+ key.public_key()
+ if isinstance(key, EllipticCurvePrivateKey)
+ else key
+ )
+ public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj: AllowedECKeys) -> str:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
@@ -502,7 +523,7 @@ if has_crypto:
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
- obj = {
+ obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
@@ -517,9 +538,7 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(
- jwk: Any,
- ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
+ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -591,7 +610,7 @@ if has_crypto:
Performs a signature using RSASSA-PSS with MGF1
"""
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(
msg,
padding.PSS(
@@ -601,7 +620,7 @@ if has_crypto:
self.hash_alg(),
)
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(
sig,
@@ -623,21 +642,20 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
- def __init__(self, **kwargs) -> None:
+ def __init__(self, **kwargs: Any) -> None:
pass
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)):
- if isinstance(key, str):
- key = key.encode("utf-8")
- str_key = key.decode("utf-8")
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+ key_bytes = key.encode("utf-8") if isinstance(key, str) else key
- if "-----BEGIN PUBLIC" in str_key:
- key = load_pem_public_key(key)
- elif "-----BEGIN PRIVATE" in str_key:
- key = load_pem_private_key(key, password=None)
- elif str_key[0:4] == "ssh-":
- key = load_ssh_public_key(key)
+ if "-----BEGIN PUBLIC" in key_str:
+ key = load_pem_public_key(key_bytes) # type: ignore[assignment]
+ elif "-----BEGIN PRIVATE" in key_str:
+ key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
+ elif key_str[0:4] == "ssh-":
+ key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
@@ -650,7 +668,9 @@ if has_crypto:
return key
- def sign(self, msg, key):
+ def sign(
+ self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
+ ) -> bytes:
"""
Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign
@@ -658,10 +678,12 @@ if has_crypto:
or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes
"""
- msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
- return key.sign(msg)
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ return key.sign(msg_bytes)
- def verify(self, msg, key, sig):
+ def verify(
+ self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
+ ) -> bool:
"""
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
@@ -672,18 +694,21 @@ if has_crypto:
:return bool verified: True if signature is valid, False if not.
"""
try:
- msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
- sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
- if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
- key = key.public_key()
- key.verify(sig, msg)
+ public_key = (
+ key.public_key()
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
+ else key
+ )
+ public_key.verify(sig_bytes, msg_bytes)
return True # If no exception was raised, the signature is valid.
- except cryptography.exceptions.InvalidSignature:
+ except InvalidSignature:
return False
@staticmethod
- def to_jwk(key):
+ def to_jwk(key: AllowedOKPKeys) -> str:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
@@ -723,7 +748,7 @@ if has_crypto:
raise InvalidKeyError("Not a public or private key")
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)