From 2fc6aa3ffecb54325d44ef076d4eff7418e97a81 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 8 Dec 2022 11:51:24 +0200 Subject: Add PyJWT._{de,en}code_payload hooks (#829) * Add PyJWT._decode_payload hook * Add PyJWT._encode_payload hook --- jwt/api_jwt.py | 49 ++++++++++++++++++++++++++++++++++++-------- tests/test_compressed_jwt.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 9 deletions(-) create mode 100644 tests/test_compressed_jwt.py diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 77857d2..e1df3c7 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -60,12 +60,32 @@ class PyJWT: if isinstance(payload.get(time_claim), datetime): payload[time_claim] = timegm(payload[time_claim].utctimetuple()) - json_payload = json.dumps( - payload, separators=(",", ":"), cls=json_encoder - ).encode("utf-8") + json_payload = self._encode_payload( + payload, + headers=headers, + json_encoder=json_encoder, + ) return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + def _encode_payload( + self, + payload: Dict[str, Any], + headers: Optional[Dict[str, Any]] = None, + json_encoder: Optional[Type[json.JSONEncoder]] = None, + ) -> bytes: + """ + Encode a given payload to the bytes to be signed. + + This method is intended to be overridden by subclasses that need to + encode the payload in a different way, e.g. compress the payload. + """ + return json.dumps( + payload, + separators=(",", ":"), + cls=json_encoder, + ).encode("utf-8") + def decode_complete( self, jwt: str, @@ -125,12 +145,7 @@ class PyJWT: detached_payload=detached_payload, ) - try: - payload = json.loads(decoded["payload"]) - except ValueError as e: - raise DecodeError(f"Invalid payload string: {e}") - if not isinstance(payload, dict): - raise DecodeError("Invalid payload string: must be a json object") + payload = self._decode_payload(decoded) merged_options = {**self.options, **options} self._validate_claims( @@ -140,6 +155,22 @@ class PyJWT: decoded["payload"] = payload return decoded + def _decode_payload(self, decoded: Dict[str, Any]) -> Any: + """ + Decode the payload from a JWS dictionary (payload, signature, header). + + This method is intended to be overridden by subclasses that need to + decode the payload in a different way, e.g. decompress compressed + payloads. + """ + try: + payload = json.loads(decoded["payload"]) + except ValueError as e: + raise DecodeError(f"Invalid payload string: {e}") + if not isinstance(payload, dict): + raise DecodeError("Invalid payload string: must be a json object") + return payload + def decode( self, jwt: str, diff --git a/tests/test_compressed_jwt.py b/tests/test_compressed_jwt.py new file mode 100644 index 0000000..21fac3f --- /dev/null +++ b/tests/test_compressed_jwt.py @@ -0,0 +1,37 @@ +import json +import zlib + +from jwt import PyJWT + + +class CompressedPyJWT(PyJWT): + def _decode_payload(self, decoded): + return json.loads( + # wbits=-15 has zlib not worry about headers of crc's + zlib.decompress(decoded["payload"], wbits=-15).decode("utf-8") + ) + + +def test_decodes_complete_valid_jwt_with_compressed_payload(): + # Test case from https://github.com/jpadilla/pyjwt/pull/753/files + example_payload = {"hello": "world"} + example_secret = "secret" + # payload made with the pako (https://nodeca.github.io/pako/) library in Javascript: + # Buffer.from(pako.deflateRaw('{"hello": "world"}')).toString('base64') + example_jwt = ( + b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + b".q1bKSM3JyVeyUlAqzy/KSVGqBQA=" + b".08wHYeuh1rJXmcBcMrz6NxmbxAnCQp2rGTKfRNIkxiw=" + ) + decoded = CompressedPyJWT().decode_complete( + example_jwt, example_secret, algorithms=["HS256"] + ) + + assert decoded == { + "header": {"alg": "HS256", "typ": "JWT"}, + "payload": example_payload, + "signature": ( + b"\xd3\xcc\x07a\xeb\xa1\xd6\xb2W\x99\xc0\\2\xbc\xfa7" + b"\x19\x9b\xc4\t\xc2B\x9d\xab\x192\x9fD\xd2$\xc6," + ), + } -- cgit v1.2.1