summaryrefslogtreecommitdiff
path: root/jwt/api_jwt.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r--jwt/api_jwt.py49
1 files changed, 40 insertions, 9 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 77857d2..e1df3c7 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -60,12 +60,32 @@ class PyJWT:
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
- json_payload = json.dumps(
- payload, separators=(",", ":"), cls=json_encoder
- ).encode("utf-8")
+ json_payload = self._encode_payload(
+ payload,
+ headers=headers,
+ json_encoder=json_encoder,
+ )
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
+ def _encode_payload(
+ self,
+ payload: Dict[str, Any],
+ headers: Optional[Dict[str, Any]] = None,
+ json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ ) -> bytes:
+ """
+ Encode a given payload to the bytes to be signed.
+
+ This method is intended to be overridden by subclasses that need to
+ encode the payload in a different way, e.g. compress the payload.
+ """
+ return json.dumps(
+ payload,
+ separators=(",", ":"),
+ cls=json_encoder,
+ ).encode("utf-8")
+
def decode_complete(
self,
jwt: str,
@@ -125,12 +145,7 @@ class PyJWT:
detached_payload=detached_payload,
)
- try:
- payload = json.loads(decoded["payload"])
- except ValueError as e:
- raise DecodeError(f"Invalid payload string: {e}")
- if not isinstance(payload, dict):
- raise DecodeError("Invalid payload string: must be a json object")
+ payload = self._decode_payload(decoded)
merged_options = {**self.options, **options}
self._validate_claims(
@@ -140,6 +155,22 @@ class PyJWT:
decoded["payload"] = payload
return decoded
+ def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
+ """
+ Decode the payload from a JWS dictionary (payload, signature, header).
+
+ This method is intended to be overridden by subclasses that need to
+ decode the payload in a different way, e.g. decompress compressed
+ payloads.
+ """
+ try:
+ payload = json.loads(decoded["payload"])
+ except ValueError as e:
+ raise DecodeError(f"Invalid payload string: {e}")
+ if not isinstance(payload, dict):
+ raise DecodeError("Invalid payload string: must be a json object")
+ return payload
+
def decode(
self,
jwt: str,