From c35e59b9f2c0bc0cf1a71b440a115d997f1e0535 Mon Sep 17 00:00:00 2001 From: Thitat Auareesuksakul Date: Wed, 10 May 2023 04:03:15 +0900 Subject: Add `as_dict` option to `Algorithm.to_jwk` (#881) * Add `as_dict` option to `Algorithm.to_jwt` * Update unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixup! Add `as_dict` option to `Algorithm.to_jwt` * fixup! Add `as_dict` option to `Algorithm.to_jwt` * fixup! Update unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix type errors * Fix tox test errors * Fix typing for Python 3.7 * Add OKP jwk tests * Add `pragma: no cover` to method overloads * Add pragma: no cover to exclude lines --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- jwt/algorithms.py | 139 +++++++++++++++++++++++++++++++++++++---------- pyproject.toml | 2 +- setup.cfg | 2 + tests/test_algorithms.py | 89 ++++++++++++++++++++---------- 4 files changed, 173 insertions(+), 59 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 4bcf81b..ed18715 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -3,8 +3,9 @@ from __future__ import annotations import hashlib import hmac import json +import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -20,6 +21,12 @@ from .utils import ( to_base64url_uint, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + try: from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend @@ -184,9 +191,21 @@ class Algorithm(ABC): for the specified message and key values. """ + @overload + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload @staticmethod @abstractmethod - def to_jwk(key_obj) -> str: + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: """ Serializes a given key into a JWK """ @@ -221,7 +240,7 @@ class NoneAlgorithm(Algorithm): return False @staticmethod - def to_jwk(key_obj: Any) -> NoReturn: + def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn: raise NotImplementedError() @staticmethod @@ -253,14 +272,27 @@ class HMACAlgorithm(Algorithm): return key_bytes + @overload @staticmethod - def to_jwk(key_obj: str | bytes) -> str: - return json.dumps( - { - "k": base64url_encode(force_bytes(key_obj)).decode(), - "kty": "oct", - } - ) + def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: + jwk = { + "k": base64url_encode(force_bytes(key_obj)).decode(), + "kty": "oct", + } + + if as_dict: + return jwk + else: + return json.dumps(jwk) @staticmethod def from_jwk(jwk: str | JWKDict) -> bytes: @@ -320,8 +352,20 @@ if has_crypto: except ValueError: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + @overload @staticmethod - def to_jwk(key_obj: AllowedRSAKeys) -> str: + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): @@ -354,7 +398,10 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - return json.dumps(obj) + if as_dict: + return obj + else: + return json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: @@ -503,8 +550,20 @@ if has_crypto: except InvalidSignature: return False + @overload @staticmethod - def to_jwk(key_obj: AllowedECKeys) -> str: + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk( + key_obj: AllowedECKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -535,7 +594,10 @@ if has_crypto: key_obj.private_numbers().private_value ).decode() - return json.dumps(obj) + if as_dict: + return obj + else: + return json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: @@ -707,21 +769,35 @@ if has_crypto: except InvalidSignature: return False + @overload @staticmethod - def to_jwk(key: AllowedOKPKeys) -> str: + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, format=PublicFormat.Raw, ) crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" - return json.dumps( - { - "x": base64url_encode(force_bytes(x)).decode(), - "kty": "OKP", - "crv": crv, - } - ) + + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): d = key.private_bytes( @@ -736,14 +812,17 @@ if has_crypto: ) crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" - return json.dumps( - { - "x": base64url_encode(force_bytes(x)).decode(), - "d": base64url_encode(force_bytes(d)).decode(), - "kty": "OKP", - "crv": crv, - } - ) + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "d": base64url_encode(force_bytes(d)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) raise InvalidKeyError("Not a public or private key") diff --git a/pyproject.toml b/pyproject.toml index 40dfb7b..613b91d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ source = ["jwt", ".tox/*/site-packages"] [tool.coverage.report] show_missing = true -exclude_lines = ["if TYPE_CHECKING:"] +exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover"] [tool.isort] profile = "black" diff --git a/setup.cfg b/setup.cfg index 08a38b4..44ccf0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,8 @@ zip_safe = false include_package_data = true python_requires = >=3.7 packages = find: +install_requires = + typing_extensions; python_version<="3.7" [options.package_data] * = py.typed diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 7a90376..1a39552 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,6 +1,6 @@ import base64 import json -from typing import cast +from typing import Any, cast import pytest @@ -91,11 +91,15 @@ class TestAlgorithms: signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) - def test_hmac_to_jwk_returns_correct_values(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_hmac_to_jwk_returns_correct_values(self, as_dict): algo = HMACAlgorithm(HMACAlgorithm.SHA256) - key = algo.to_jwk("secret") + key: Any = algo.to_jwk("secret", as_dict=as_dict) - assert json.loads(key) == {"kty": "oct", "k": "c2VjcmV0"} + if not as_dict: + key = json.loads(key) + + assert key == {"kty": "oct", "k": "c2VjcmV0"} def test_hmac_from_jwk_should_raise_exception_if_not_hmac_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) @@ -261,13 +265,17 @@ class TestAlgorithms: assert parsed_key.public_numbers() == orig_key.public_numbers() @crypto_required - def test_ec_to_jwk_returns_correct_values_for_public_key(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_ec_to_jwk_returns_correct_values_for_public_key(self, as_dict): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.pub")) as keyfile: pub_key = algo.prepare_key(keyfile.read()) - key = algo.to_jwk(pub_key) + key: Any = algo.to_jwk(pub_key, as_dict=as_dict) + + if not as_dict: + key = json.loads(key) expected = { "kty": "EC", @@ -276,16 +284,20 @@ class TestAlgorithms: "y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ", } - assert json.loads(key) == expected + assert key == expected @crypto_required - def test_ec_to_jwk_returns_correct_values_for_private_key(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_ec_to_jwk_returns_correct_values_for_private_key(self, as_dict): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.priv")) as keyfile: priv_key = algo.prepare_key(keyfile.read()) - key = algo.to_jwk(priv_key) + key: Any = algo.to_jwk(priv_key, as_dict=as_dict) + + if not as_dict: + key = json.loads(key) expected = { "kty": "EC", @@ -295,17 +307,18 @@ class TestAlgorithms: "d": "2nninfu2jMHDwAbn9oERUhRADS6duQaJEadybLaa0YQ", } - assert json.loads(key) == expected + assert key == expected @crypto_required def test_ec_to_jwk_raises_exception_on_invalid_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] + algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload] @crypto_required - def test_ec_to_jwk_with_valid_curves(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_ec_to_jwk_with_valid_curves(self, as_dict): tests = { "P-256": ECAlgorithm.SHA256, "P-384": ECAlgorithm.SHA384, @@ -317,11 +330,21 @@ class TestAlgorithms: with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile: pub_key = algo.from_jwk(keyfile.read()) - assert json.loads(algo.to_jwk(pub_key))["crv"] == curve + jwk: Any = algo.to_jwk(pub_key, as_dict=as_dict) + + if not as_dict: + jwk = json.loads(jwk) + + assert jwk["crv"] == curve with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile: priv_key = algo.from_jwk(keyfile.read()) - assert json.loads(algo.to_jwk(priv_key))["crv"] == curve + jwk = algo.to_jwk(priv_key, as_dict=as_dict) + + if not as_dict: + jwk = json.loads(jwk) + + assert jwk["crv"] == curve @crypto_required def test_ec_to_jwk_with_invalid_curve(self): @@ -440,13 +463,17 @@ class TestAlgorithms: algo.from_jwk('{"kty": "RSA"}') @crypto_required - def test_rsa_to_jwk_returns_correct_values_for_public_key(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_rsa_to_jwk_returns_correct_values_for_public_key(self, as_dict): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.pub")) as keyfile: pub_key = algo.prepare_key(keyfile.read()) - key = algo.to_jwk(pub_key) + key: Any = algo.to_jwk(pub_key, as_dict=as_dict) + + if not as_dict: + key = json.loads(key) expected = { "e": "AQAB", @@ -461,16 +488,20 @@ class TestAlgorithms: "sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ" ), } - assert json.loads(key) == expected + assert key == expected @crypto_required - def test_rsa_to_jwk_returns_correct_values_for_private_key(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_rsa_to_jwk_returns_correct_values_for_private_key(self, as_dict): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.priv")) as keyfile: priv_key = algo.prepare_key(keyfile.read()) - key = algo.to_jwk(priv_key) + key: Any = algo.to_jwk(priv_key, as_dict=as_dict) + + if not as_dict: + key = json.loads(key) expected = { "key_ops": ["sign"], @@ -518,14 +549,14 @@ class TestAlgorithms: "AuKhin-kc4mh9ssDXRQZwlMymZP0QtaxUDw_nlfVrUCZgO7L1_ZsUTk" ), } - assert json.loads(key) == expected + assert key == expected @crypto_required def test_rsa_to_jwk_raises_exception_on_invalid_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] + algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload] @crypto_required def test_rsa_from_jwk_raises_exception_on_invalid_key(self): @@ -923,7 +954,8 @@ class TestOKPAlgorithms: with pytest.raises(InvalidKeyError): algo.from_jwk(v) - def test_okp_ed25519_to_jwk_works_with_from_jwk(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_okp_ed25519_to_jwk_works_with_from_jwk(self, as_dict): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: @@ -932,9 +964,9 @@ class TestOKPAlgorithms: with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile: pub_key_1 = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read())) - pub = algo.to_jwk(pub_key_1) + pub = algo.to_jwk(pub_key_1, as_dict=as_dict) pub_key_2 = algo.from_jwk(pub) - pri = algo.to_jwk(priv_key_1) + pri = algo.to_jwk(priv_key_1, as_dict=as_dict) priv_key_2 = cast(Ed25519PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) @@ -946,7 +978,7 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] + algo.to_jwk({"not": "a valid key"}) # type: ignore[call-overload] def test_okp_ed448_jwk_private_key_should_parse_and_verify(self): algo = OKPAlgorithm() @@ -1032,7 +1064,8 @@ class TestOKPAlgorithms: with pytest.raises(InvalidKeyError): algo.from_jwk(v) - def test_okp_ed448_to_jwk_works_with_from_jwk(self): + @pytest.mark.parametrize("as_dict", (False, True)) + def test_okp_ed448_to_jwk_works_with_from_jwk(self, as_dict): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: @@ -1041,9 +1074,9 @@ class TestOKPAlgorithms: with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile: pub_key_1 = cast(Ed448PublicKey, algo.from_jwk(keyfile.read())) - pub = algo.to_jwk(pub_key_1) + pub = algo.to_jwk(pub_key_1, as_dict=as_dict) pub_key_2 = algo.from_jwk(pub) - pri = algo.to_jwk(priv_key_1) + pri = algo.to_jwk(priv_key_1, as_dict=as_dict) priv_key_2 = cast(Ed448PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) -- cgit v1.2.1