diff options
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r-- | jwt/api_jws.py | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py index c914db1..0a775d7 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ from __future__ import annotations import binascii import json import warnings -from typing import Any, Type +from typing import Any, List, Optional, Type from .algorithms import ( Algorithm, @@ -24,7 +24,11 @@ from .warnings import RemovedInPyjwt3Warning class PyJWS: header_typ = "JWT" - def __init__(self, algorithms=None, options=None) -> None: + def __init__( + self, + algorithms: Optional[List[str]] = None, + options: Optional[dict[str, Any]] = None, + ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( set(algorithms) if algorithms is not None else set(self._algorithms) @@ -164,8 +168,8 @@ class PyJWS: def decode_complete( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -209,13 +213,13 @@ class PyJWS: def decode( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, - ) -> str: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode() is deprecated " @@ -228,7 +232,7 @@ class PyJWS: ) return decoded["payload"] - def get_unverified_header(self, jwt: str | bytes) -> dict: + def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]: """Returns back the JWT header parameters as a dict() Note: The signature is not verified so the header parameters @@ -239,7 +243,7 @@ class PyJWS: return headers - def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]: if isinstance(jwt, str): jwt = jwt.encode("utf-8") @@ -280,13 +284,15 @@ class PyJWS: def _verify_signature( self, signing_input: bytes, - header: dict, + header: dict[str, Any], signature: bytes, - key: str = "", + key: str | bytes = "", algorithms: list[str] | None = None, ) -> None: - - alg = header.get("alg") + try: + alg = header["alg"] + except KeyError: + raise InvalidAlgorithmError("Algorithm not specified") if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") @@ -304,7 +310,7 @@ class PyJWS: if "kid" in headers: self._validate_kid(headers["kid"]) - def _validate_kid(self, kid: str) -> None: + def _validate_kid(self, kid: Any) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string") |