summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py139
1 files changed, 109 insertions, 30 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 4bcf81b..ed18715 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -3,8 +3,9 @@ from __future__ import annotations
import hashlib
import hmac
import json
+import sys
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
+from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
@@ -20,6 +21,12 @@ from .utils import (
to_base64url_uint,
)
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+
try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
@@ -184,9 +191,21 @@ class Algorithm(ABC):
for the specified message and key values.
"""
+ @overload
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
@staticmethod
@abstractmethod
- def to_jwk(key_obj) -> str:
+ def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given key into a JWK
"""
@@ -221,7 +240,7 @@ class NoneAlgorithm(Algorithm):
return False
@staticmethod
- def to_jwk(key_obj: Any) -> NoReturn:
+ def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()
@staticmethod
@@ -253,14 +272,27 @@ class HMACAlgorithm(Algorithm):
return key_bytes
+ @overload
@staticmethod
- def to_jwk(key_obj: str | bytes) -> str:
- return json.dumps(
- {
- "k": base64url_encode(force_bytes(key_obj)).decode(),
- "kty": "oct",
- }
- )
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
+ jwk = {
+ "k": base64url_encode(force_bytes(key_obj)).decode(),
+ "kty": "oct",
+ }
+
+ if as_dict:
+ return jwk
+ else:
+ return json.dumps(jwk)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
@@ -320,8 +352,20 @@ if has_crypto:
except ValueError:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
+ @overload
@staticmethod
- def to_jwk(key_obj: AllowedRSAKeys) -> str:
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedRSAKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"):
@@ -354,7 +398,10 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- return json.dumps(obj)
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
@@ -503,8 +550,20 @@ if has_crypto:
except InvalidSignature:
return False
+ @overload
@staticmethod
- def to_jwk(key_obj: AllowedECKeys) -> str:
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedECKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
@@ -535,7 +594,10 @@ if has_crypto:
key_obj.private_numbers().private_value
).decode()
- return json.dumps(obj)
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
@@ -707,21 +769,35 @@ if has_crypto:
except InvalidSignature:
return False
+ @overload
@staticmethod
- def to_jwk(key: AllowedOKPKeys) -> str:
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
- return json.dumps(
- {
- "x": base64url_encode(force_bytes(x)).decode(),
- "kty": "OKP",
- "crv": crv,
- }
- )
+
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
@@ -736,14 +812,17 @@ if has_crypto:
)
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
- return json.dumps(
- {
- "x": base64url_encode(force_bytes(x)).decode(),
- "d": base64url_encode(force_bytes(d)).decode(),
- "kty": "OKP",
- "crv": crv,
- }
- )
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "d": base64url_encode(force_bytes(d)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
raise InvalidKeyError("Not a public or private key")