summaryrefslogtreecommitdiff
path: root/jwt/api_jws.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r--jwt/api_jws.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index c914db1..0a775d7 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import binascii
import json
import warnings
-from typing import Any, Type
+from typing import Any, List, Optional, Type
from .algorithms import (
Algorithm,
@@ -24,7 +24,11 @@ from .warnings import RemovedInPyjwt3Warning
class PyJWS:
header_typ = "JWT"
- def __init__(self, algorithms=None, options=None) -> None:
+ def __init__(
+ self,
+ algorithms: Optional[List[str]] = None,
+ options: Optional[dict[str, Any]] = None,
+ ) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
@@ -164,8 +168,8 @@ class PyJWS:
def decode_complete(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -209,13 +213,13 @@ class PyJWS:
def decode(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
- ) -> str:
+ ) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@@ -228,7 +232,7 @@ class PyJWS:
)
return decoded["payload"]
- def get_unverified_header(self, jwt: str | bytes) -> dict:
+ def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
@@ -239,7 +243,7 @@ class PyJWS:
return headers
- def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
+ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
@@ -280,13 +284,15 @@ class PyJWS:
def _verify_signature(
self,
signing_input: bytes,
- header: dict,
+ header: dict[str, Any],
signature: bytes,
- key: str = "",
+ key: str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
-
- alg = header.get("alg")
+ try:
+ alg = header["alg"]
+ except KeyError:
+ raise InvalidAlgorithmError("Algorithm not specified")
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")
@@ -304,7 +310,7 @@ class PyJWS:
if "kid" in headers:
self._validate_kid(headers["kid"])
- def _validate_kid(self, kid: str) -> None:
+ def _validate_kid(self, kid: Any) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")