summaryrefslogtreecommitdiff
path: root/jwt/api_jws.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r--jwt/api_jws.py61
1 files changed, 31 insertions, 30 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 90206c9..ab8490f 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
+
import binascii
import json
import warnings
-from collections.abc import Mapping
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Type
from .algorithms import (
Algorithm,
@@ -23,7 +24,7 @@ from .warnings import RemovedInPyjwt3Warning
class PyJWS:
header_typ = "JWT"
- def __init__(self, algorithms=None, options=None):
+ def __init__(self, algorithms=None, options=None) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
@@ -39,10 +40,10 @@ class PyJWS:
self.options = {**self._get_default_options(), **options}
@staticmethod
- def _get_default_options():
+ def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True}
- def register_algorithm(self, alg_id, alg_obj):
+ def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
@@ -55,7 +56,7 @@ class PyJWS:
self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id)
- def unregister_algorithm(self, alg_id):
+ def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
@@ -69,7 +70,7 @@ class PyJWS:
del self._algorithms[alg_id]
self._valid_algs.remove(alg_id)
- def get_algorithms(self):
+ def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
@@ -96,9 +97,9 @@ class PyJWS:
self,
payload: bytes,
key: str,
- algorithm: Optional[str] = "HS256",
- headers: Optional[Dict[str, Any]] = None,
- json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ algorithm: str | None = "HS256",
+ headers: dict[str, Any] | None = None,
+ json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
) -> str:
segments = []
@@ -117,7 +118,7 @@ class PyJWS:
is_payload_detached = True
# Header
- header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
+ header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
if headers:
self._validate_headers(headers)
@@ -165,11 +166,11 @@ class PyJWS:
self,
jwt: str,
key: str = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
- detached_payload: Optional[bytes] = None,
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ detached_payload: bytes | None = None,
**kwargs,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
@@ -210,9 +211,9 @@ class PyJWS:
self,
jwt: str,
key: str = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
- detached_payload: Optional[bytes] = None,
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ detached_payload: bytes | None = None,
**kwargs,
) -> str:
if kwargs:
@@ -227,7 +228,7 @@ class PyJWS:
)
return decoded["payload"]
- def get_unverified_header(self, jwt):
+ def get_unverified_header(self, jwt: str | bytes) -> dict:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
@@ -238,7 +239,7 @@ class PyJWS:
return headers
- def _load(self, jwt):
+ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
@@ -261,7 +262,7 @@ class PyJWS:
except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e
- if not isinstance(header, Mapping):
+ if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object")
try:
@@ -278,16 +279,16 @@ class PyJWS:
def _verify_signature(
self,
- signing_input,
- header,
- signature,
- key="",
- algorithms=None,
- ):
+ signing_input: bytes,
+ header: dict,
+ signature: bytes,
+ key: str = "",
+ algorithms: list[str] | None = None,
+ ) -> None:
alg = header.get("alg")
- if algorithms is not None and alg not in algorithms:
+ if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")
try:
@@ -299,11 +300,11 @@ class PyJWS:
if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")
- def _validate_headers(self, headers):
+ def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])
- def _validate_kid(self, kid):
+ def _validate_kid(self, kid: str) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")