diff options
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r-- | jwt/api_jws.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 0a775d7..fa6708c 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ from __future__ import annotations import binascii import json import warnings -from typing import Any, List, Optional, Type +from typing import TYPE_CHECKING, Any from .algorithms import ( Algorithm, @@ -20,14 +20,17 @@ from .exceptions import ( from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWS: header_typ = "JWT" def __init__( self, - algorithms: Optional[List[str]] = None, - options: Optional[dict[str, Any]] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( @@ -100,10 +103,10 @@ class PyJWS: def encode( self, payload: bytes, - key: str, + key: AllowedPrivateKeys | str | bytes, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, - json_encoder: Type[json.JSONEncoder] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, sort_headers: bool = True, ) -> str: @@ -169,7 +172,7 @@ class PyJWS: def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -214,7 +217,7 @@ class PyJWS: def decode( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -286,7 +289,7 @@ class PyJWS: signing_input: bytes, header: dict[str, Any], signature: bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, ) -> None: try: @@ -301,9 +304,9 @@ class PyJWS: alg_obj = self.get_algorithm_by_name(alg) except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e - key = alg_obj.prepare_key(key) + prepared_key = alg_obj.prepare_key(key) - if not alg_obj.verify(signing_input, key, signature): + if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers: dict[str, Any]) -> None: |