summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py29
1 files changed, 19 insertions, 10 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 277f940..bc928fe 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,6 +1,7 @@
import hashlib
import hmac
import json
+from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type, Union
from .exceptions import InvalidKeyError
@@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]:
return default_algorithms
-class Algorithm:
+class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
@@ -148,40 +149,40 @@ class Algorithm:
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.
+ @abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
- raise NotImplementedError
+ @abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
- raise NotImplementedError
+ @abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
- raise NotImplementedError
class NoneAlgorithm(Algorithm):
@@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm):
def verify(self, msg, key, sig):
return False
+ @staticmethod
+ def to_jwk(key_obj) -> JWKDict:
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_jwk(jwk: JWKDict):
+ raise NotImplementedError()
+
class HMACAlgorithm(Algorithm):
"""
@@ -299,7 +308,7 @@ if has_crypto:
def to_jwk(key_obj):
obj = None
- if getattr(key_obj, "private_numbers", None):
+ if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()
@@ -316,7 +325,7 @@ if has_crypto:
"qi": to_base64url_uint(numbers.iqmp).decode(),
}
- elif getattr(key_obj, "verify", None):
+ elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()
@@ -587,7 +596,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
@@ -599,7 +608,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)