summaryrefslogtreecommitdiff
path: root/jwt/api_jws.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r--jwt/api_jws.py53
1 files changed, 31 insertions, 22 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 16ec846..f8c8048 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -73,6 +73,23 @@ class PyJWS:
"""
return list(self._valid_algs)
+ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
+ """
+ For a given string name, return the matching Algorithm object.
+
+ Example usage:
+
+ >>> jws_obj.get_algorithm_by_name("RS256")
+ """
+ try:
+ return self._algorithms[alg_name]
+ except KeyError as e:
+ if not has_crypto and alg_name in requires_cryptography:
+ raise NotImplementedError(
+ f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
+ ) from e
+ raise NotImplementedError("Algorithm not supported") from e
+
def encode(
self,
payload: bytes,
@@ -84,21 +101,21 @@ class PyJWS:
) -> str:
segments = []
- if algorithm is None:
- algorithm = "none"
+ # declare a new var to narrow the type for type checkers
+ algorithm_: str = algorithm if algorithm is not None else "none"
# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
- algorithm = headers["alg"]
+ algorithm_ = headers["alg"]
headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True
# Header
- header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any]
+ header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
if headers:
self._validate_headers(headers)
@@ -128,17 +145,9 @@ class PyJWS:
# Segments
signing_input = b".".join(segments)
- try:
- alg_obj = self._algorithms[algorithm]
- key = alg_obj.prepare_key(key)
- signature = alg_obj.sign(signing_input, key)
-
- except KeyError as e:
- if not has_crypto and algorithm in requires_cryptography:
- raise NotImplementedError(
- f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
- ) from e
- raise NotImplementedError("Algorithm not supported") from e
+ alg_obj = self.get_algorithm_by_name(algorithm_)
+ key = alg_obj.prepare_key(key)
+ signature = alg_obj.sign(signing_input, key)
segments.append(base64url_encode(signature))
@@ -262,14 +271,13 @@ class PyJWS:
raise InvalidAlgorithmError("The specified alg value is not allowed")
try:
- alg_obj = self._algorithms[alg]
- key = alg_obj.prepare_key(key)
-
- if not alg_obj.verify(signing_input, key, signature):
- raise InvalidSignatureError("Signature verification failed")
-
- except KeyError as e:
+ alg_obj = self.get_algorithm_by_name(alg)
+ except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
+ key = alg_obj.prepare_key(key)
+
+ if not alg_obj.verify(signing_input, key, signature):
+ raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers):
if "kid" in headers:
@@ -286,4 +294,5 @@ decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
+get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header