summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jwt/api_jws.py4
-rw-r--r--jwt/api_jwt.py10
-rw-r--r--tests/test_api_jws.py11
3 files changed, 22 insertions, 3 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index ab8490f..c914db1 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -101,6 +101,7 @@ class PyJWS:
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
+ sort_headers: bool = True,
) -> str:
segments = []
@@ -133,9 +134,8 @@ class PyJWS:
# True is the standard value for b64, so no need for it
del header["b64"]
- # Fix for headers misorder - issue #715
json_header = json.dumps(
- header, separators=(",", ":"), cls=json_encoder, sort_keys=True
+ header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()
segments.append(base64url_encode(json_header))
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index e1df3c7..5d21afc 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -45,6 +45,7 @@ class PyJWT:
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
@@ -66,7 +67,14 @@ class PyJWT:
json_encoder=json_encoder,
)
- return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
+ return api_jws.encode(
+ json_payload,
+ key,
+ algorithm,
+ headers,
+ json_encoder,
+ sort_headers=sort_headers,
+ )
def _encode_payload(
self,
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index cfbbe21..fe6c2d4 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -414,6 +414,17 @@ class TestJWS:
assert decoded_payload == payload
+ @pytest.mark.parametrize("sort_headers", (False, True))
+ def test_sorting_of_headers(self, jws, payload, sort_headers):
+ jws_message = jws.encode(
+ payload,
+ key="\xc2",
+ headers={"b": "1", "a": "2"},
+ sort_headers=sort_headers,
+ )
+ header_json = base64url_decode(jws_message.split(".")[0])
+ assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"'))
+
def test_decode_invalid_header_padding(self, jws):
example_jws = (
"aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"