From 777efa2f51249f63b0f95804230117723eca5d09 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 7 Mar 2023 03:22:39 +0100 Subject: Make `Algorithm` an abstract base class (#845) * Make `Algorithm` an abstract base class This also removes some tests that are not relevant anymore Raise `NotImplementedError` for `NoneAlgorithm` * Use `hasattr` instead of `getattr` * Only allow `dict` in `encode` --- jwt/algorithms.py | 29 +++++++++++++++++++---------- jwt/api_jwt.py | 8 ++++---- tests/test_algorithms.py | 46 +++++++++++----------------------------------- tests/test_api_jws.py | 6 +++--- 4 files changed, 37 insertions(+), 52 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 277f940..bc928fe 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,6 +1,7 @@ import hashlib import hmac import json +from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, Type, Union from .exceptions import InvalidKeyError @@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]: return default_algorithms -class Algorithm: +class Algorithm(ABC): """ The interface for an algorithm used to sign and verify tokens. """ @@ -148,40 +149,40 @@ class Algorithm: # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 # that may still be poorly supported. + @abstractmethod def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ - raise NotImplementedError + @abstractmethod def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ - raise NotImplementedError + @abstractmethod def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. """ - raise NotImplementedError @staticmethod + @abstractmethod def to_jwk(key_obj) -> JWKDict: """ Serializes a given RSA key into a JWK """ - raise NotImplementedError @staticmethod + @abstractmethod def from_jwk(jwk: JWKDict): """ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object """ - raise NotImplementedError class NoneAlgorithm(Algorithm): @@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm): def verify(self, msg, key, sig): return False + @staticmethod + def to_jwk(key_obj) -> JWKDict: + raise NotImplementedError() + + @staticmethod + def from_jwk(jwk: JWKDict): + raise NotImplementedError() + class HMACAlgorithm(Algorithm): """ @@ -299,7 +308,7 @@ if has_crypto: def to_jwk(key_obj): obj = None - if getattr(key_obj, "private_numbers", None): + if hasattr(key_obj, "private_numbers"): # Private key numbers = key_obj.private_numbers() @@ -316,7 +325,7 @@ if has_crypto: "qi": to_base64url_uint(numbers.iqmp).decode(), } - elif getattr(key_obj, "verify", None): + elif hasattr(key_obj, "verify"): # Public key numbers = key_obj.public_numbers() @@ -587,7 +596,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) @@ -599,7 +608,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 5664949..d85f6e8 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import warnings from calendar import timegm -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Type, Union @@ -47,10 +47,10 @@ class PyJWT: json_encoder: Optional[Type[json.JSONEncoder]] = None, sort_headers: bool = True, ) -> str: - # Check that we get a mapping - if not isinstance(payload, Mapping): + # Check that we get a dict + if not isinstance(payload, dict): raise TypeError( - "Expecting a mapping object, as JWT only supports " + "Expecting a dict object, as JWT only supports " "JSON objects as payloads." ) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 1533b5d..8aa9ad7 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -3,7 +3,7 @@ import json import pytest -from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm, has_crypto +from jwt.algorithms import HMACAlgorithm, NoneAlgorithm, has_crypto from jwt.exceptions import InvalidKeyError from jwt.utils import base64url_decode @@ -15,47 +15,23 @@ if has_crypto: class TestAlgorithms: - def test_algorithm_should_throw_exception_if_prepare_key_not_impl(self): - algo = Algorithm() - - with pytest.raises(NotImplementedError): - algo.prepare_key("test") - - def test_algorithm_should_throw_exception_if_sign_not_impl(self): - algo = Algorithm() - - with pytest.raises(NotImplementedError): - algo.sign(b"message", "key") - - def test_algorithm_should_throw_exception_if_verify_not_impl(self): - algo = Algorithm() - - with pytest.raises(NotImplementedError): - algo.verify(b"message", "key", b"signature") - - def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self): - algo = Algorithm() - - with pytest.raises(NotImplementedError): - algo.from_jwk({"val": "ue"}) - - def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self): - algo = Algorithm() + def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): + algo = NoneAlgorithm() - with pytest.raises(NotImplementedError): - algo.to_jwk("value") + with pytest.raises(InvalidKeyError): + algo.prepare_key("123") - def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self): - algo = Algorithm() + def test_none_algorithm_should_throw_exception_on_to_jwk(self): + algo = NoneAlgorithm() with pytest.raises(NotImplementedError): - algo.compute_hash_digest(b"value") + algo.to_jwk("dummy") # Using a dummy argument as is it not relevant - def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): + def test_none_algorithm_should_throw_exception_on_from_jwk(self): algo = NoneAlgorithm() - with pytest.raises(InvalidKeyError): - algo.prepare_key("123") + with pytest.raises(NotImplementedError): + algo.from_jwk({}) # Using a dummy argument as is it not relevant def test_hmac_should_reject_nonstring_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 2929215..d2aa915 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -3,7 +3,7 @@ from decimal import Decimal import pytest -from jwt.algorithms import Algorithm, has_crypto +from jwt.algorithms import NoneAlgorithm, has_crypto from jwt.api_jws import PyJWS from jwt.exceptions import ( DecodeError, @@ -39,10 +39,10 @@ def payload(): class TestJWS: def test_register_algo_does_not_allow_duplicate_registration(self, jws): - jws.register_algorithm("AAA", Algorithm()) + jws.register_algorithm("AAA", NoneAlgorithm()) with pytest.raises(ValueError): - jws.register_algorithm("AAA", Algorithm()) + jws.register_algorithm("AAA", NoneAlgorithm()) def test_register_algo_rejects_non_algorithm_obj(self, jws): with pytest.raises(TypeError): -- cgit v1.2.1