summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViicos <65306057+Viicos@users.noreply.github.com>2023-03-07 03:22:39 +0100
committerGitHub <noreply@github.com>2023-03-06 21:22:39 -0500
commit777efa2f51249f63b0f95804230117723eca5d09 (patch)
treef8d7db72d0806380daa91b37edde3c27c32eedea
parent5a2a6b640d9780c75817a626672da0b2d38f9e78 (diff)
downloadpyjwt-777efa2f51249f63b0f95804230117723eca5d09.tar.gz
Make `Algorithm` an abstract base class (#845)
* Make `Algorithm` an abstract base class This also removes some tests that are not relevant anymore Raise `NotImplementedError` for `NoneAlgorithm` * Use `hasattr` instead of `getattr` * Only allow `dict` in `encode`
-rw-r--r--jwt/algorithms.py29
-rw-r--r--jwt/api_jwt.py8
-rw-r--r--tests/test_algorithms.py46
-rw-r--r--tests/test_api_jws.py6
4 files changed, 37 insertions, 52 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 277f940..bc928fe 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,6 +1,7 @@
import hashlib
import hmac
import json
+from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type, Union
from .exceptions import InvalidKeyError
@@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]:
return default_algorithms
-class Algorithm:
+class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
@@ -148,40 +149,40 @@ class Algorithm:
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.
+ @abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
- raise NotImplementedError
+ @abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
- raise NotImplementedError
+ @abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
- raise NotImplementedError
@staticmethod
+ @abstractmethod
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
- raise NotImplementedError
class NoneAlgorithm(Algorithm):
@@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm):
def verify(self, msg, key, sig):
return False
+ @staticmethod
+ def to_jwk(key_obj) -> JWKDict:
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_jwk(jwk: JWKDict):
+ raise NotImplementedError()
+
class HMACAlgorithm(Algorithm):
"""
@@ -299,7 +308,7 @@ if has_crypto:
def to_jwk(key_obj):
obj = None
- if getattr(key_obj, "private_numbers", None):
+ if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()
@@ -316,7 +325,7 @@ if has_crypto:
"qi": to_base64url_uint(numbers.iqmp).decode(),
}
- elif getattr(key_obj, "verify", None):
+ elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()
@@ -587,7 +596,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
@@ -599,7 +608,7 @@ if has_crypto:
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
- salt_length=self.hash_alg.digest_size,
+ salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 5664949..d85f6e8 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import warnings
from calendar import timegm
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union
@@ -47,10 +47,10 @@ class PyJWT:
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
- # Check that we get a mapping
- if not isinstance(payload, Mapping):
+ # Check that we get a dict
+ if not isinstance(payload, dict):
raise TypeError(
- "Expecting a mapping object, as JWT only supports "
+ "Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 1533b5d..8aa9ad7 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -3,7 +3,7 @@ import json
import pytest
-from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm, has_crypto
+from jwt.algorithms import HMACAlgorithm, NoneAlgorithm, has_crypto
from jwt.exceptions import InvalidKeyError
from jwt.utils import base64url_decode
@@ -15,47 +15,23 @@ if has_crypto:
class TestAlgorithms:
- def test_algorithm_should_throw_exception_if_prepare_key_not_impl(self):
- algo = Algorithm()
-
- with pytest.raises(NotImplementedError):
- algo.prepare_key("test")
-
- def test_algorithm_should_throw_exception_if_sign_not_impl(self):
- algo = Algorithm()
-
- with pytest.raises(NotImplementedError):
- algo.sign(b"message", "key")
-
- def test_algorithm_should_throw_exception_if_verify_not_impl(self):
- algo = Algorithm()
-
- with pytest.raises(NotImplementedError):
- algo.verify(b"message", "key", b"signature")
-
- def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
- algo = Algorithm()
-
- with pytest.raises(NotImplementedError):
- algo.from_jwk({"val": "ue"})
-
- def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
- algo = Algorithm()
+ def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
+ algo = NoneAlgorithm()
- with pytest.raises(NotImplementedError):
- algo.to_jwk("value")
+ with pytest.raises(InvalidKeyError):
+ algo.prepare_key("123")
- def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self):
- algo = Algorithm()
+ def test_none_algorithm_should_throw_exception_on_to_jwk(self):
+ algo = NoneAlgorithm()
with pytest.raises(NotImplementedError):
- algo.compute_hash_digest(b"value")
+ algo.to_jwk("dummy") # Using a dummy argument as is it not relevant
- def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
+ def test_none_algorithm_should_throw_exception_on_from_jwk(self):
algo = NoneAlgorithm()
- with pytest.raises(InvalidKeyError):
- algo.prepare_key("123")
+ with pytest.raises(NotImplementedError):
+ algo.from_jwk({}) # Using a dummy argument as is it not relevant
def test_hmac_should_reject_nonstring_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index 2929215..d2aa915 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pytest
-from jwt.algorithms import Algorithm, has_crypto
+from jwt.algorithms import NoneAlgorithm, has_crypto
from jwt.api_jws import PyJWS
from jwt.exceptions import (
DecodeError,
@@ -39,10 +39,10 @@ def payload():
class TestJWS:
def test_register_algo_does_not_allow_duplicate_registration(self, jws):
- jws.register_algorithm("AAA", Algorithm())
+ jws.register_algorithm("AAA", NoneAlgorithm())
with pytest.raises(ValueError):
- jws.register_algorithm("AAA", Algorithm())
+ jws.register_algorithm("AAA", NoneAlgorithm())
def test_register_algo_rejects_non_algorithm_obj(self, jws):
with pytest.raises(TypeError):