summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThitat Auareesuksakul <thitat@flux.ci>2023-05-10 04:03:15 +0900
committerGitHub <noreply@github.com>2023-05-09 15:03:15 -0400
commitc35e59b9f2c0bc0cf1a71b440a115d997f1e0535 (patch)
tree5d66b73e1edc9595355cce63fbdb4512c70af721
parent6a273419949b68ddccbe3867fd4bd8680cacf097 (diff)
downloadpyjwt-c35e59b9f2c0bc0cf1a71b440a115d997f1e0535.tar.gz
Add `as_dict` option to `Algorithm.to_jwk` (#881)
* Add `as_dict` option to `Algorithm.to_jwt` * Update unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixup! Add `as_dict` option to `Algorithm.to_jwt` * fixup! Add `as_dict` option to `Algorithm.to_jwt` * fixup! Update unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix type errors * Fix tox test errors * Fix typing for Python 3.7 * Add OKP jwk tests * Add `pragma: no cover` to method overloads * Add pragma: no cover to exclude lines --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-rw-r--r--jwt/algorithms.py139
-rw-r--r--pyproject.toml2
-rw-r--r--setup.cfg2
-rw-r--r--tests/test_algorithms.py89
4 files changed, 173 insertions, 59 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 4bcf81b..ed18715 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -3,8 +3,9 @@ from __future__ import annotations
import hashlib
import hmac
import json
+import sys
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
+from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
@@ -20,6 +21,12 @@ from .utils import (
to_base64url_uint,
)
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+
try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
@@ -184,9 +191,21 @@ class Algorithm(ABC):
for the specified message and key values.
"""
+ @overload
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
@staticmethod
@abstractmethod
- def to_jwk(key_obj) -> str:
+ def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given key into a JWK
"""
@@ -221,7 +240,7 @@ class NoneAlgorithm(Algorithm):
return False
@staticmethod
- def to_jwk(key_obj: Any) -> NoReturn:
+ def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()
@staticmethod
@@ -253,14 +272,27 @@ class HMACAlgorithm(Algorithm):
return key_bytes
+ @overload
@staticmethod
- def to_jwk(key_obj: str | bytes) -> str:
- return json.dumps(
- {
- "k": base64url_encode(force_bytes(key_obj)).decode(),
- "kty": "oct",
- }
- )
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
+ jwk = {
+ "k": base64url_encode(force_bytes(key_obj)).decode(),
+ "kty": "oct",
+ }
+
+ if as_dict:
+ return jwk
+ else:
+ return json.dumps(jwk)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
@@ -320,8 +352,20 @@ if has_crypto:
except ValueError:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
+ @overload
@staticmethod
- def to_jwk(key_obj: AllowedRSAKeys) -> str:
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedRSAKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"):
@@ -354,7 +398,10 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- return json.dumps(obj)
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
@@ -503,8 +550,20 @@ if has_crypto:
except InvalidSignature:
return False
+ @overload
@staticmethod
- def to_jwk(key_obj: AllowedECKeys) -> str:
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedECKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
@@ -535,7 +594,10 @@ if has_crypto:
key_obj.private_numbers().private_value
).decode()
- return json.dumps(obj)
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
@@ -707,21 +769,35 @@ if has_crypto:
except InvalidSignature:
return False
+ @overload
@staticmethod
- def to_jwk(key: AllowedOKPKeys) -> str:
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
- return json.dumps(
- {
- "x": base64url_encode(force_bytes(x)).decode(),
- "kty": "OKP",
- "crv": crv,
- }
- )
+
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
@@ -736,14 +812,17 @@ if has_crypto:
)
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
- return json.dumps(
- {
- "x": base64url_encode(force_bytes(x)).decode(),
- "d": base64url_encode(force_bytes(d)).decode(),
- "kty": "OKP",
- "crv": crv,
- }
- )
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "d": base64url_encode(force_bytes(d)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
raise InvalidKeyError("Not a public or private key")
diff --git a/pyproject.toml b/pyproject.toml
index 40dfb7b..613b91d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ source = ["jwt", ".tox/*/site-packages"]
[tool.coverage.report]
show_missing = true
-exclude_lines = ["if TYPE_CHECKING:"]
+exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover"]
[tool.isort]
profile = "black"
diff --git a/setup.cfg b/setup.cfg
index 08a38b4..44ccf0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,6 +35,8 @@ zip_safe = false
include_package_data = true
python_requires = >=3.7
packages = find:
+install_requires =
+ typing_extensions; python_version<="3.7"
[options.package_data]
* = py.typed
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 7a90376..1a39552 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,6 +1,6 @@
import base64
import json
-from typing import cast
+from typing import Any, cast
import pytest
@@ -91,11 +91,15 @@ class TestAlgorithms:
signature = algo.sign(b"Hello World!", key)
assert algo.verify(b"Hello World!", key, signature)
- def test_hmac_to_jwk_returns_correct_values(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_hmac_to_jwk_returns_correct_values(self, as_dict):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
- key = algo.to_jwk("secret")
+ key: Any = algo.to_jwk("secret", as_dict=as_dict)
- assert json.loads(key) == {"kty": "oct", "k": "c2VjcmV0"}
+ if not as_dict:
+ key = json.loads(key)
+
+ assert key == {"kty": "oct", "k": "c2VjcmV0"}
def test_hmac_from_jwk_should_raise_exception_if_not_hmac_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
@@ -261,13 +265,17 @@ class TestAlgorithms:
assert parsed_key.public_numbers() == orig_key.public_numbers()
@crypto_required
- def test_ec_to_jwk_returns_correct_values_for_public_key(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_ec_to_jwk_returns_correct_values_for_public_key(self, as_dict):
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path("testkey_ec.pub")) as keyfile:
pub_key = algo.prepare_key(keyfile.read())
- key = algo.to_jwk(pub_key)
+ key: Any = algo.to_jwk(pub_key, as_dict=as_dict)
+
+ if not as_dict:
+ key = json.loads(key)
expected = {
"kty": "EC",
@@ -276,16 +284,20 @@ class TestAlgorithms:
"y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ",
}
- assert json.loads(key) == expected
+ assert key == expected
@crypto_required
- def test_ec_to_jwk_returns_correct_values_for_private_key(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_ec_to_jwk_returns_correct_values_for_private_key(self, as_dict):
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path("testkey_ec.priv")) as keyfile:
priv_key = algo.prepare_key(keyfile.read())
- key = algo.to_jwk(priv_key)
+ key: Any = algo.to_jwk(priv_key, as_dict=as_dict)
+
+ if not as_dict:
+ key = json.loads(key)
expected = {
"kty": "EC",
@@ -295,17 +307,18 @@ class TestAlgorithms:
"d": "2nninfu2jMHDwAbn9oERUhRADS6duQaJEadybLaa0YQ",
}
- assert json.loads(key) == expected
+ assert key == expected
@crypto_required
def test_ec_to_jwk_raises_exception_on_invalid_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload]
@crypto_required
- def test_ec_to_jwk_with_valid_curves(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_ec_to_jwk_with_valid_curves(self, as_dict):
tests = {
"P-256": ECAlgorithm.SHA256,
"P-384": ECAlgorithm.SHA384,
@@ -317,11 +330,21 @@ class TestAlgorithms:
with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile:
pub_key = algo.from_jwk(keyfile.read())
- assert json.loads(algo.to_jwk(pub_key))["crv"] == curve
+ jwk: Any = algo.to_jwk(pub_key, as_dict=as_dict)
+
+ if not as_dict:
+ jwk = json.loads(jwk)
+
+ assert jwk["crv"] == curve
with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile:
priv_key = algo.from_jwk(keyfile.read())
- assert json.loads(algo.to_jwk(priv_key))["crv"] == curve
+ jwk = algo.to_jwk(priv_key, as_dict=as_dict)
+
+ if not as_dict:
+ jwk = json.loads(jwk)
+
+ assert jwk["crv"] == curve
@crypto_required
def test_ec_to_jwk_with_invalid_curve(self):
@@ -440,13 +463,17 @@ class TestAlgorithms:
algo.from_jwk('{"kty": "RSA"}')
@crypto_required
- def test_rsa_to_jwk_returns_correct_values_for_public_key(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_rsa_to_jwk_returns_correct_values_for_public_key(self, as_dict):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path("testkey_rsa.pub")) as keyfile:
pub_key = algo.prepare_key(keyfile.read())
- key = algo.to_jwk(pub_key)
+ key: Any = algo.to_jwk(pub_key, as_dict=as_dict)
+
+ if not as_dict:
+ key = json.loads(key)
expected = {
"e": "AQAB",
@@ -461,16 +488,20 @@ class TestAlgorithms:
"sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ"
),
}
- assert json.loads(key) == expected
+ assert key == expected
@crypto_required
- def test_rsa_to_jwk_returns_correct_values_for_private_key(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_rsa_to_jwk_returns_correct_values_for_private_key(self, as_dict):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path("testkey_rsa.priv")) as keyfile:
priv_key = algo.prepare_key(keyfile.read())
- key = algo.to_jwk(priv_key)
+ key: Any = algo.to_jwk(priv_key, as_dict=as_dict)
+
+ if not as_dict:
+ key = json.loads(key)
expected = {
"key_ops": ["sign"],
@@ -518,14 +549,14 @@ class TestAlgorithms:
"AuKhin-kc4mh9ssDXRQZwlMymZP0QtaxUDw_nlfVrUCZgO7L1_ZsUTk"
),
}
- assert json.loads(key) == expected
+ assert key == expected
@crypto_required
def test_rsa_to_jwk_raises_exception_on_invalid_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload]
@crypto_required
def test_rsa_from_jwk_raises_exception_on_invalid_key(self):
@@ -923,7 +954,8 @@ class TestOKPAlgorithms:
with pytest.raises(InvalidKeyError):
algo.from_jwk(v)
- def test_okp_ed25519_to_jwk_works_with_from_jwk(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_okp_ed25519_to_jwk_works_with_from_jwk(self, as_dict):
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile:
@@ -932,9 +964,9 @@ class TestOKPAlgorithms:
with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile:
pub_key_1 = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read()))
- pub = algo.to_jwk(pub_key_1)
+ pub = algo.to_jwk(pub_key_1, as_dict=as_dict)
pub_key_2 = algo.from_jwk(pub)
- pri = algo.to_jwk(priv_key_1)
+ pri = algo.to_jwk(priv_key_1, as_dict=as_dict)
priv_key_2 = cast(Ed25519PrivateKey, algo.from_jwk(pri))
signature_1 = algo.sign(b"Hello World!", priv_key_1)
@@ -946,7 +978,7 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload]
def test_okp_ed448_jwk_private_key_should_parse_and_verify(self):
algo = OKPAlgorithm()
@@ -1032,7 +1064,8 @@ class TestOKPAlgorithms:
with pytest.raises(InvalidKeyError):
algo.from_jwk(v)
- def test_okp_ed448_to_jwk_works_with_from_jwk(self):
+ @pytest.mark.parametrize("as_dict", (False, True))
+ def test_okp_ed448_to_jwk_works_with_from_jwk(self, as_dict):
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed448.json")) as keyfile:
@@ -1041,9 +1074,9 @@ class TestOKPAlgorithms:
with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile:
pub_key_1 = cast(Ed448PublicKey, algo.from_jwk(keyfile.read()))
- pub = algo.to_jwk(pub_key_1)
+ pub = algo.to_jwk(pub_key_1, as_dict=as_dict)
pub_key_2 = algo.from_jwk(pub)
- pri = algo.to_jwk(priv_key_1)
+ pri = algo.to_jwk(priv_key_1, as_dict=as_dict)
priv_key_2 = cast(Ed448PrivateKey, algo.from_jwk(pri))
signature_1 = algo.sign(b"Hello World!", priv_key_1)