summaryrefslogtreecommitdiff
path: root/jwt/api_jws.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r--jwt/api_jws.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 0a775d7..fa6708c 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import binascii
import json
import warnings
-from typing import Any, List, Optional, Type
+from typing import TYPE_CHECKING, Any
from .algorithms import (
Algorithm,
@@ -20,14 +20,17 @@ from .exceptions import (
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
+if TYPE_CHECKING:
+ from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
+
class PyJWS:
header_typ = "JWT"
def __init__(
self,
- algorithms: Optional[List[str]] = None,
- options: Optional[dict[str, Any]] = None,
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
@@ -100,10 +103,10 @@ class PyJWS:
def encode(
self,
payload: bytes,
- key: str,
+ key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
- json_encoder: Type[json.JSONEncoder] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
@@ -169,7 +172,7 @@ class PyJWS:
def decode_complete(
self,
jwt: str | bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -214,7 +217,7 @@ class PyJWS:
def decode(
self,
jwt: str | bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -286,7 +289,7 @@ class PyJWS:
signing_input: bytes,
header: dict[str, Any],
signature: bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
try:
@@ -301,9 +304,9 @@ class PyJWS:
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
- key = alg_obj.prepare_key(key)
+ prepared_key = alg_obj.prepare_key(key)
- if not alg_obj.verify(signing_input, key, signature):
+ if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers: dict[str, Any]) -> None: