summaryrefslogtreecommitdiff
path: root/jwt
diff options
context:
space:
mode:
authorViicos <65306057+Viicos@users.noreply.github.com>2023-03-07 03:22:39 +0100
committerGitHub <noreply@github.com>2023-03-06 21:22:39 -0500
commit777efa2f51249f63b0f95804230117723eca5d09 (patch)
treef8d7db72d0806380daa91b37edde3c27c32eedea /jwt
parent5a2a6b640d9780c75817a626672da0b2d38f9e78 (diff)
downloadpyjwt-777efa2f51249f63b0f95804230117723eca5d09.tar.gz
Make `Algorithm` an abstract base class (#845)
* Make `Algorithm` an abstract base class This also removes some tests that are not relevant anymore Raise `NotImplementedError` for `NoneAlgorithm` * Use `hasattr` instead of `getattr` * Only allow `dict` in `encode`
Diffstat (limited to 'jwt')
-rw-r--r--jwt/algorithms.py29
-rw-r--r--jwt/api_jwt.py8
2 files changed, 23 insertions, 14 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 277f940..bc928fe 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,6 +1,7 @@
import hashlib
import hmac
import json
+from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type, Union
from .exceptions import InvalidKeyError
@@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]:
return default_algorithms
-class Algorithm:
+class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
@@ -148,40 +149,40 @@ class Algorithm:
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.
+ @abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
- raise NotImplementedError
+ @abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
- raise NotImplementedError
+ @abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
- raise NotImplementedError
class NoneAlgorithm(Algorithm):
@@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm):
def verify(self, msg, key, sig):
return False
+ @staticmethod
+ def to_jwk(key_obj) -> JWKDict:
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_jwk(jwk: JWKDict):
+ raise NotImplementedError()
+
class HMACAlgorithm(Algorithm):
"""
@@ -299,7 +308,7 @@ if has_crypto:
def to_jwk(key_obj):
obj = None
- if getattr(key_obj, "private_numbers", None):
+ if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()
@@ -316,7 +325,7 @@ if has_crypto:
"qi": to_base64url_uint(numbers.iqmp).decode(),
}
- elif getattr(key_obj, "verify", None):
+ elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()
@@ -587,7 +596,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
@@ -599,7 +608,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 5664949..d85f6e8 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import warnings
from calendar import timegm
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union
@@ -47,10 +47,10 @@ class PyJWT:
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
- # Check that we get a mapping
- if not isinstance(payload, Mapping):
+ # Check that we get a dict
+ if not isinstance(payload, dict):
raise TypeError(
- "Expecting a mapping object, as JWT only supports "
+ "Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)