diff options
Diffstat (limited to 'jwt')
-rw-r--r-- | jwt/algorithms.py | 29 | ||||
-rw-r--r-- | jwt/api_jwt.py | 8 |
2 files changed, 23 insertions, 14 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 277f940..bc928fe 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,6 +1,7 @@ import hashlib import hmac import json +from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, Type, Union from .exceptions import InvalidKeyError @@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]: return default_algorithms -class Algorithm: +class Algorithm(ABC): """ The interface for an algorithm used to sign and verify tokens. """ @@ -148,40 +149,40 @@ class Algorithm: # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 # that may still be poorly supported. + @abstractmethod def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ - raise NotImplementedError + @abstractmethod def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ - raise NotImplementedError + @abstractmethod def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. """ - raise NotImplementedError @staticmethod + @abstractmethod def to_jwk(key_obj) -> JWKDict: """ Serializes a given RSA key into a JWK """ - raise NotImplementedError @staticmethod + @abstractmethod def from_jwk(jwk: JWKDict): """ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object """ - raise NotImplementedError class NoneAlgorithm(Algorithm): @@ -205,6 +206,14 @@ class NoneAlgorithm(Algorithm): def verify(self, msg, key, sig): return False + @staticmethod + def to_jwk(key_obj) -> JWKDict: + raise NotImplementedError() + + @staticmethod + def from_jwk(jwk: JWKDict): + raise NotImplementedError() + class HMACAlgorithm(Algorithm): """ @@ -299,7 +308,7 @@ if has_crypto: def to_jwk(key_obj): obj = None - if getattr(key_obj, "private_numbers", None): + if hasattr(key_obj, "private_numbers"): # Private key numbers = key_obj.private_numbers() @@ -316,7 +325,7 @@ if has_crypto: "qi": to_base64url_uint(numbers.iqmp).decode(), } - elif getattr(key_obj, "verify", None): + elif hasattr(key_obj, "verify"): # Public key numbers = key_obj.public_numbers() @@ -587,7 +596,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) @@ -599,7 +608,7 @@ if has_crypto: msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 5664949..d85f6e8 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import warnings from calendar import timegm -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Type, Union @@ -47,10 +47,10 @@ class PyJWT: json_encoder: Optional[Type[json.JSONEncoder]] = None, sort_headers: bool = True, ) -> str: - # Check that we get a mapping - if not isinstance(payload, Mapping): + # Check that we get a dict + if not isinstance(payload, dict): raise TypeError( - "Expecting a mapping object, as JWT only supports " + "Expecting a dict object, as JWT only supports " "JSON objects as payloads." ) |