diff options
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r-- | jwt/api_jwt.py | 67 |
1 files changed, 35 insertions, 32 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index d85f6e8..49d1b48 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ import warnings from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any from . import api_jws from .exceptions import ( @@ -19,15 +19,18 @@ from .exceptions import ( ) from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWT: - def __init__(self, options: Optional[dict[str, Any]] = None) -> None: + def __init__(self, options: dict[str, Any] | None = None) -> None: if options is None: options = {} self.options: dict[str, Any] = {**self._get_default_options(), **options} @staticmethod - def _get_default_options() -> Dict[str, Union[bool, List[str]]]: + def _get_default_options() -> dict[str, bool | list[str]]: return { "verify_signature": True, "verify_exp": True, @@ -40,11 +43,11 @@ class PyJWT: def encode( self, - payload: Dict[str, Any], - key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + key: AllowedPrivateKeys | str | bytes, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, ) -> str: # Check that we get a dict @@ -78,9 +81,9 @@ class PyJWT: def _encode_payload( self, - payload: Dict[str, Any], - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, ) -> bytes: """ Encode a given payload to the bytes to be signed. @@ -97,21 +100,21 @@ class PyJWT: def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, - ) -> Dict[str, Any]: + **kwargs: Any, + ) -> dict[str, Any]: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -163,7 +166,7 @@ class PyJWT: decoded["payload"] = payload return decoded - def _decode_payload(self, decoded: Dict[str, Any]) -> Any: + def _decode_payload(self, decoded: dict[str, Any]) -> Any: """ Decode the payload from a JWS dictionary (payload, signature, header). @@ -182,20 +185,20 @@ class PyJWT: def decode( self, jwt: str | bytes, - key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, + **kwargs: Any, ) -> Any: if kwargs: warnings.warn( @@ -303,7 +306,7 @@ class PyJWT: def _validate_aud( self, payload: dict[str, Any], - audience: Optional[Union[str, Iterable[str]]], + audience: str | Iterable[str] | None, ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: |