summaryrefslogtreecommitdiff
path: root/jwt/api_jwt.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r--jwt/api_jwt.py67
1 files changed, 35 insertions, 32 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index d85f6e8..49d1b48 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -5,7 +5,7 @@ import warnings
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any
from . import api_jws
from .exceptions import (
@@ -19,15 +19,18 @@ from .exceptions import (
)
from .warnings import RemovedInPyjwt3Warning
+if TYPE_CHECKING:
+ from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
+
class PyJWT:
- def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
+ def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
- def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
+ def _get_default_options() -> dict[str, bool | list[str]]:
return {
"verify_signature": True,
"verify_exp": True,
@@ -40,11 +43,11 @@ class PyJWT:
def encode(
self,
- payload: Dict[str, Any],
- key: str,
- algorithm: Optional[str] = "HS256",
- headers: Optional[Dict[str, Any]] = None,
- json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ payload: dict[str, Any],
+ key: AllowedPrivateKeys | str | bytes,
+ algorithm: str | None = "HS256",
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str:
# Check that we get a dict
@@ -78,9 +81,9 @@ class PyJWT:
def _encode_payload(
self,
- payload: Dict[str, Any],
- headers: Optional[Dict[str, Any]] = None,
- json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ payload: dict[str, Any],
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
@@ -97,21 +100,21 @@ class PyJWT:
def decode_complete(
self,
jwt: str | bytes,
- key: str | bytes = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
- verify: Optional[bool] = None,
+ verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
- detached_payload: Optional[bytes] = None,
+ detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
- audience: Optional[Union[str, Iterable[str]]] = None,
- issuer: Optional[str] = None,
- leeway: Union[int, float, timedelta] = 0,
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
# kwargs
- **kwargs,
- ) -> Dict[str, Any]:
+ **kwargs: Any,
+ ) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
@@ -163,7 +166,7 @@ class PyJWT:
decoded["payload"] = payload
return decoded
- def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
+ def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
@@ -182,20 +185,20 @@ class PyJWT:
def decode(
self,
jwt: str | bytes,
- key: str | bytes = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
- verify: Optional[bool] = None,
+ verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
- detached_payload: Optional[bytes] = None,
+ detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
- audience: Optional[Union[str, Iterable[str]]] = None,
- issuer: Optional[str] = None,
- leeway: Union[int, float, timedelta] = 0,
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
# kwargs
- **kwargs,
+ **kwargs: Any,
) -> Any:
if kwargs:
warnings.warn(
@@ -303,7 +306,7 @@ class PyJWT:
def _validate_aud(
self,
payload: dict[str, Any],
- audience: Optional[Union[str, Iterable[str]]],
+ audience: str | Iterable[str] | None,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]: