diff options
-rw-r--r-- | jwt/api_jws.py | 4 | ||||
-rw-r--r-- | jwt/api_jwt.py | 10 | ||||
-rw-r--r-- | tests/test_api_jws.py | 11 |
3 files changed, 22 insertions, 3 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py index ab8490f..c914db1 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -101,6 +101,7 @@ class PyJWS: headers: dict[str, Any] | None = None, json_encoder: Type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, + sort_headers: bool = True, ) -> str: segments = [] @@ -133,9 +134,8 @@ class PyJWS: # True is the standard value for b64, so no need for it del header["b64"] - # Fix for headers misorder - issue #715 json_header = json.dumps( - header, separators=(",", ":"), cls=json_encoder, sort_keys=True + header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers ).encode() segments.append(base64url_encode(json_header)) diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index e1df3c7..5d21afc 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -45,6 +45,7 @@ class PyJWT: algorithm: Optional[str] = "HS256", headers: Optional[Dict[str, Any]] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, + sort_headers: bool = True, ) -> str: # Check that we get a mapping if not isinstance(payload, Mapping): @@ -66,7 +67,14 @@ class PyJWT: json_encoder=json_encoder, ) - return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + return api_jws.encode( + json_payload, + key, + algorithm, + headers, + json_encoder, + sort_headers=sort_headers, + ) def _encode_payload( self, diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index cfbbe21..fe6c2d4 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -414,6 +414,17 @@ class TestJWS: assert decoded_payload == payload + @pytest.mark.parametrize("sort_headers", (False, True)) + def test_sorting_of_headers(self, jws, payload, sort_headers): + jws_message = jws.encode( + payload, + key="\xc2", + headers={"b": "1", "a": "2"}, + sort_headers=sort_headers, + ) + header_json = base64url_decode(jws_message.split(".")[0]) + assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"')) + def test_decode_invalid_header_padding(self, jws): example_jws = ( "aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" |