summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAjitomi, Daisuke <ajitomi@gmail.com>2021-08-06 11:34:24 +0900
committerGitHub <noreply@github.com>2021-08-05 22:34:24 -0400
commit8de44287752f9d2c0112bb3062389462d56ead86 (patch)
treeadb964790095647e9b31737cd38d9f278011903c
parentcfe52613ca0fad1535bcf5b5753aca5760043cae (diff)
downloadpyjwt-8de44287752f9d2c0112bb3062389462d56ead86.tar.gz
Prefer headers['alg'] to algorithm parameter in encode(). (#673)
* Prefer headers['alg'] to algorithm parameter in encode(). * Fix lack of @crypto_required. * Prefer headers['alg'] to algorithm parameter in encode(). * Prefer headers['alg'] to algorithm parameter in encode(). * Make algorithm parameter of encode() Optioanl explicitly.
-rw-r--r--CHANGELOG.rst1
-rw-r--r--docs/api.rst5
-rw-r--r--jwt/api_jws.py6
-rw-r--r--jwt/api_jwt.py2
-rw-r--r--tests/test_api_jws.py26
5 files changed, 36 insertions, 4 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 763822a..b049e9e 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -13,6 +13,7 @@ Changed
Fixed
~~~~~
+- Prefer `headers["alg"]` to `algorithm` in `jwt.encode()`. `#673 <https://github.com/jpadilla/pyjwt/pull/673>`__
- Fix aud validation to support {'aud': null} case. `#670 <https://github.com/jpadilla/pyjwt/pull/670>`__
Added
diff --git a/docs/api.rst b/docs/api.rst
index d23d4bd..f1720b9 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -13,8 +13,9 @@ API Reference
* for **asymmetric algorithms**: PEM-formatted private key, a multiline string
* for **symmetric algorithms**: plain string, sufficiently long for security
- :param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``
- :param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``
+ :param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``.
+ If ``headers`` includes ``alg``, it will be preferred to this parameter.
+ :param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``.
:param json.JSONEncoder json_encoder: custom JSON encoder for ``payload`` and ``headers``
:rtype: str
:returns: a JSON Web Token
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 6d13699..6cefb2c 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -77,7 +77,7 @@ class PyJWS:
self,
payload: bytes,
key: str,
- algorithm: str = "HS256",
+ algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> str:
@@ -86,6 +86,10 @@ class PyJWS:
if algorithm is None:
algorithm = "none"
+ # Prefer headers["alg"] if present to algorithm parameter.
+ if headers and "alg" in headers and headers["alg"]:
+ algorithm = headers["alg"]
+
if algorithm not in self._valid_algs:
pass
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 55a8f29..48a9316 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -38,7 +38,7 @@ class PyJWT:
self,
payload: Dict[str, Any],
key: str,
- algorithm: str = "HS256",
+ algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> str:
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index b731edc..fbeda02 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -166,6 +166,32 @@ class TestJWS:
exception = context.value
assert str(exception) == "Algorithm not supported"
+ def test_encode_with_headers_alg_none(self, jws, payload):
+ msg = jws.encode(payload, key=None, headers={"alg": "none"})
+ with pytest.raises(DecodeError) as context:
+ jws.decode(msg, algorithms=["none"])
+ assert str(context.value) == "Signature verification failed"
+
+ @crypto_required
+ def test_encode_with_headers_alg_es256(self, jws, payload):
+ with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
+ priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
+ with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
+ pub_key = load_pem_public_key(ec_pub_file.read())
+
+ msg = jws.encode(payload, priv_key, headers={"alg": "ES256"})
+ assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])
+
+ @crypto_required
+ def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
+ with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
+ priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
+ with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
+ pub_key = load_pem_public_key(ec_pub_file.read())
+
+ msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
+ assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])
+
def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256