diff options
Diffstat (limited to 'jwt/api_jws.py')
-rw-r--r-- | jwt/api_jws.py | 53 |
1 files changed, 31 insertions, 22 deletions
diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 16ec846..f8c8048 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -73,6 +73,23 @@ class PyJWS: """ return list(self._valid_algs) + def get_algorithm_by_name(self, alg_name: str) -> Algorithm: + """ + For a given string name, return the matching Algorithm object. + + Example usage: + + >>> jws_obj.get_algorithm_by_name("RS256") + """ + try: + return self._algorithms[alg_name] + except KeyError as e: + if not has_crypto and alg_name in requires_cryptography: + raise NotImplementedError( + f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?" + ) from e + raise NotImplementedError("Algorithm not supported") from e + def encode( self, payload: bytes, @@ -84,21 +101,21 @@ class PyJWS: ) -> str: segments = [] - if algorithm is None: - algorithm = "none" + # declare a new var to narrow the type for type checkers + algorithm_: str = algorithm if algorithm is not None else "none" # Prefer headers values if present to function parameters. if headers: headers_alg = headers.get("alg") if headers_alg: - algorithm = headers["alg"] + algorithm_ = headers["alg"] headers_b64 = headers.get("b64") if headers_b64 is False: is_payload_detached = True # Header - header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any] + header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any] if headers: self._validate_headers(headers) @@ -128,17 +145,9 @@ class PyJWS: # Segments signing_input = b".".join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) - - except KeyError as e: - if not has_crypto and algorithm in requires_cryptography: - raise NotImplementedError( - f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?" - ) from e - raise NotImplementedError("Algorithm not supported") from e + alg_obj = self.get_algorithm_by_name(algorithm_) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) segments.append(base64url_encode(signature)) @@ -262,14 +271,13 @@ class PyJWS: raise InvalidAlgorithmError("The specified alg value is not allowed") try: - alg_obj = self._algorithms[alg] - key = alg_obj.prepare_key(key) - - if not alg_obj.verify(signing_input, key, signature): - raise InvalidSignatureError("Signature verification failed") - - except KeyError as e: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e + key = alg_obj.prepare_key(key) + + if not alg_obj.verify(signing_input, key, signature): + raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers): if "kid" in headers: @@ -286,4 +294,5 @@ decode_complete = _jws_global_obj.decode_complete decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm +get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name get_unverified_header = _jws_global_obj.get_unverified_header |