summaryrefslogtreecommitdiff
path: root/jwt/api_jwt.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r--jwt/api_jwt.py58
1 files changed, 44 insertions, 14 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 5d21afc..8840708 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -21,10 +21,10 @@ from .warnings import RemovedInPyjwt3Warning
class PyJWT:
- def __init__(self, options=None):
+ def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
if options is None:
options = {}
- self.options = {**self._get_default_options(), **options}
+ self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -96,8 +96,8 @@ class PyJWT:
def decode_complete(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
@@ -181,8 +181,8 @@ class PyJWT:
def decode(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
@@ -196,7 +196,7 @@ class PyJWT:
leeway: Union[int, float, timedelta] = 0,
# kwargs
**kwargs,
- ) -> Dict[str, Any]:
+ ) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@@ -217,7 +217,14 @@ class PyJWT:
)
return decoded["payload"]
- def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
+ def _validate_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ audience=None,
+ issuer=None,
+ leeway: float | timedelta = 0,
+ ) -> None:
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
@@ -243,12 +250,21 @@ class PyJWT:
if options["verify_aud"]:
self._validate_aud(payload, audience)
- def _validate_required_claims(self, payload, options):
+ def _validate_required_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ ) -> None:
for claim in options["require"]:
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)
- def _validate_iat(self, payload, now, leeway):
+ def _validate_iat(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
iat = payload["iat"]
try:
int(iat)
@@ -257,7 +273,12 @@ class PyJWT:
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
- def _validate_nbf(self, payload, now, leeway):
+ def _validate_nbf(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
try:
nbf = int(payload["nbf"])
except ValueError:
@@ -266,7 +287,12 @@ class PyJWT:
if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
- def _validate_exp(self, payload, now, leeway):
+ def _validate_exp(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
try:
exp = int(payload["exp"])
except ValueError:
@@ -275,7 +301,11 @@ class PyJWT:
if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
- def _validate_aud(self, payload, audience):
+ def _validate_aud(
+ self,
+ payload: dict[str, Any],
+ audience: Optional[Union[str, Iterable[str]]],
+ ) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
return
@@ -303,7 +333,7 @@ class PyJWT:
if all(aud not in audience_claims for aud in audience):
raise InvalidAudienceError("Invalid audience")
- def _validate_iss(self, payload, issuer):
+ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if issuer is None:
return