diff options
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r-- | jwt/algorithms.py | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 277f940..bc928fe 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,6 +1,7 @@ import hashlib import hmac import json +from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, Type, Union from .exceptions import InvalidKeyError @@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]: return default_algorithms -class Algorithm: +class Algorithm(ABC): """ The interface for an algorithm used to sign and verify tokens. """ @@ -148,40 +149,40 @@ class Algorithm: # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 # that may still be poorly supported. + @abstractmethod def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ - raise NotImplementedError + @abstractmethod def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ - raise NotImplementedError + @abstractmethod def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. """ - raise NotImplementedError @staticmethod + @abstractmethod def to_jwk(key_obj) -> JWKDict: """ Serializes a given RSA key into a JWK """ - raise NotImplementedError @staticmethod + @abstractmethod def from_jwk(jwk: JWKDict): """ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object """ - raise NotImplementedError class NoneAlgorithm(Algorithm): @@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm): def verify(self, msg, key, sig): return False + @staticmethod + def to_jwk(key_obj) -> JWKDict: + raise NotImplementedError() + + @staticmethod + def from_jwk(jwk: JWKDict): + raise NotImplementedError() + class HMACAlgorithm(Algorithm): """ @@ -299,7 +308,7 @@ if has_crypto: def to_jwk(key_obj): obj = None - if getattr(key_obj, "private_numbers", None): + if hasattr(key_obj, "private_numbers"): # Private key numbers = key_obj.private_numbers() @@ -316,7 +325,7 @@ if has_crypto: "qi": to_base64url_uint(numbers.iqmp).decode(), } - elif getattr(key_obj, "verify", None): + elif hasattr(key_obj, "verify"): # Public key numbers = key_obj.public_numbers() @@ -587,7 +596,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) @@ -599,7 +608,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) |