summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephen Rosen <sirosen@globus.org>2022-07-03 11:18:46 -0400
committerGitHub <noreply@github.com>2022-07-03 11:18:46 -0400
commit2dab2107543d2ffed554aec332438dc16f0b2015 (patch)
tree0cffb919c6b232e44da33be052107c53f02f18b5
parent381a15a21fafb54348e4abc7f354eb5486509a93 (diff)
downloadpyjwt-2dab2107543d2ffed554aec332438dc16f0b2015.tar.gz
Expose get_algorithm_by_name as new method (#773)
* Expose get_algorithm_by_name as new method Looking up an algorithm by name is used internally for signature generation. This encapsulates that functionality in a dedicated method and adds it to the public API. No new tests are needed to exercise the functionality. Rationale: 1. Inside of PyJWS, this improves the code. The KeyError handler is better scoped and the signing code reads more directly. 2. This is part of the path to supporting OIDC at_hash validation as a use-case (see: #295, #296, #314). This is arguably sufficient to consider that use-case supported and close it. However, it is an improvement and step in the right direction in either case. A minor change was needed to satisfy mypy, as a union-typed variable does not narrow its type based on assignments. The easiest resolution is to use a new name, in this case, simply `algorithm -> algorithm_`. * Use get_algorithm_by_name in _verify_signature Rather than catching the KeyError from a dict lookup, catch the NotImplementedError raised by get_algorithm_by_name. This changes the exception seen in the cause under exception chaining but otherwise has no public-facing impact.
-rw-r--r--CHANGELOG.rst2
-rw-r--r--jwt/__init__.py2
-rw-r--r--jwt/api_jws.py53
3 files changed, 35 insertions, 22 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index c11ce06..81d421e 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -19,6 +19,8 @@ Fixed
Added
~~~~~
- Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732
+- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose
+ the global PyJWS method as part of the public API
`v2.4.0 <https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0>`__
-----------------------------------------------------------------------
diff --git a/jwt/__init__.py b/jwt/__init__.py
index 6b3f8ab..a96cc6e 100644
--- a/jwt/__init__.py
+++ b/jwt/__init__.py
@@ -1,6 +1,7 @@
from .api_jwk import PyJWK, PyJWKSet
from .api_jws import (
PyJWS,
+ get_algorithm_by_name,
get_unverified_header,
register_algorithm,
unregister_algorithm,
@@ -51,6 +52,7 @@ __all__ = [
"get_unverified_header",
"register_algorithm",
"unregister_algorithm",
+ "get_algorithm_by_name",
# Exceptions
"DecodeError",
"ExpiredSignatureError",
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 16ec846..f8c8048 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -73,6 +73,23 @@ class PyJWS:
"""
return list(self._valid_algs)
+ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
+ """
+ For a given string name, return the matching Algorithm object.
+
+ Example usage:
+
+ >>> jws_obj.get_algorithm_by_name("RS256")
+ """
+ try:
+ return self._algorithms[alg_name]
+ except KeyError as e:
+ if not has_crypto and alg_name in requires_cryptography:
+ raise NotImplementedError(
+ f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
+ ) from e
+ raise NotImplementedError("Algorithm not supported") from e
+
def encode(
self,
payload: bytes,
@@ -84,21 +101,21 @@ class PyJWS:
) -> str:
segments = []
- if algorithm is None:
- algorithm = "none"
+ # declare a new var to narrow the type for type checkers
+ algorithm_: str = algorithm if algorithm is not None else "none"
# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
- algorithm = headers["alg"]
+ algorithm_ = headers["alg"]
headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True
# Header
- header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any]
+ header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
if headers:
self._validate_headers(headers)
@@ -128,17 +145,9 @@ class PyJWS:
# Segments
signing_input = b".".join(segments)
- try:
- alg_obj = self._algorithms[algorithm]
- key = alg_obj.prepare_key(key)
- signature = alg_obj.sign(signing_input, key)
-
- except KeyError as e:
- if not has_crypto and algorithm in requires_cryptography:
- raise NotImplementedError(
- f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
- ) from e
- raise NotImplementedError("Algorithm not supported") from e
+ alg_obj = self.get_algorithm_by_name(algorithm_)
+ key = alg_obj.prepare_key(key)
+ signature = alg_obj.sign(signing_input, key)
segments.append(base64url_encode(signature))
@@ -262,14 +271,13 @@ class PyJWS:
raise InvalidAlgorithmError("The specified alg value is not allowed")
try:
- alg_obj = self._algorithms[alg]
- key = alg_obj.prepare_key(key)
-
- if not alg_obj.verify(signing_input, key, signature):
- raise InvalidSignatureError("Signature verification failed")
-
- except KeyError as e:
+ alg_obj = self.get_algorithm_by_name(alg)
+ except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
+ key = alg_obj.prepare_key(key)
+
+ if not alg_obj.verify(signing_input, key, signature):
+ raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers):
if "kid" in headers:
@@ -286,4 +294,5 @@ decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
+get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header