From fb9b3119449353dca6c03be27f32e482852473c3 Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Thu, 8 Dec 2022 16:50:04 +0100 Subject: Add `sort_headers` parameter to `api_jwt.encode` (#832) * Add `sort_headers` parameter to `api_jwt.encode` This allows you to not sort headers, which prevents a breaking change between v2.4.0 and v2.5.0 * Add `test_sorting_headers` test * Remove outdated comment about misordered headers * Explicity assert sorting in `test_sorting_of_headers` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Parametrize `test_sorting_of_headers` * Use normal dict in `test_sorting_of_headers` * fixup! Use normal dict in `test_sorting_of_headers` Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- jwt/api_jws.py | 4 ++-- jwt/api_jwt.py | 10 +++++++++- tests/test_api_jws.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index ab8490f..c914db1 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -101,6 +101,7 @@ class PyJWS: headers: dict[str, Any] | None = None, json_encoder: Type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, + sort_headers: bool = True, ) -> str: segments = [] @@ -133,9 +134,8 @@ class PyJWS: # True is the standard value for b64, so no need for it del header["b64"] - # Fix for headers misorder - issue #715 json_header = json.dumps( - header, separators=(",", ":"), cls=json_encoder, sort_keys=True + header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers ).encode() segments.append(base64url_encode(json_header)) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index e1df3c7..5d21afc 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -45,6 +45,7 @@ class PyJWT: algorithm: Optional[str] = "HS256", headers: Optional[Dict[str, Any]] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, + sort_headers: bool = True, ) -> str: # Check that we get a mapping if not isinstance(payload, Mapping): @@ -66,7 +67,14 @@ class PyJWT: json_encoder=json_encoder, ) - return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + return api_jws.encode( + json_payload, + key, + algorithm, + headers, + json_encoder, + sort_headers=sort_headers, + ) def _encode_payload( self, diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index cfbbe21..fe6c2d4 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -414,6 +414,17 @@ class TestJWS: assert decoded_payload == payload + @pytest.mark.parametrize("sort_headers", (False, True)) + def test_sorting_of_headers(self, jws, payload, sort_headers): + jws_message = jws.encode( + payload, + key="\xc2", + headers={"b": "1", "a": "2"}, + sort_headers=sort_headers, + ) + header_json = base64url_decode(jws_message.split(".")[0]) + assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"')) + def test_decode_invalid_header_padding(self, jws): example_jws = ( "aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" -- cgit v1.2.1