summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephen Rosen <sirosen@globus.org>2022-11-02 06:01:52 -0500
committerGitHub <noreply@github.com>2022-11-02 17:01:52 +0600
commit00cd759d86aae24176ead7bdbed273a07532443e (patch)
tree34f5fd7e5d0c55299fc5c8685fb558aeebaf9f4a
parent345549567dbb58fd7bf901392cf6b1a626f36e24 (diff)
downloadpyjwt-00cd759d86aae24176ead7bdbed273a07532443e.tar.gz
Add `Algorithm.compute_hash_digest` and use it to implement at_hash validation example (#775)
* Add compute_hash_digest to Algorithm objects `Algorithm.compute_hash_digest` is defined as a method which inspects the object to see that it has the requisite attributes, `hash_alg`. If `hash_alg` is not set, then the method raises a NotImplementedError. This applies to classes like NoneAlgorithm. If `hash_alg` is set, then it is checked for ``` has_crypto # is cryptography available? and isinstance(hash_alg, type) and issubclass(hash_alg, hashes.HashAlgorithm) ``` to see which API for computing a digest is appropriate -- `hashlib` vs `cryptography.hazmat.primitives.hashes`. These checks could be avoided at runtime if it were necessary to optimize further (e.g. attach compute_hash_digest methods to classes with a class decorator) but this is not clearly a worthwhile optimization. Such perf tuning is intentionally omitted for now. * Add doc example of OIDC login flow The goal of this doc example is to demonstrate usage of `get_algorithm_by_name` and `compute_hash_digest` for the purpose of `at_hash` validation. It is not meant to be a "guaranteed correct" and spec-compliant example. closes #314
-rw-r--r--CHANGELOG.rst4
-rw-r--r--docs/usage.rst71
-rw-r--r--jwt/algorithms.py23
-rw-r--r--tests/test_algorithms.py23
4 files changed, 121 insertions, 0 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 8ec7ede..4d56207 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -18,6 +18,10 @@ Fixed
Added
~~~~~
+- Add ``compute_hash_digest`` as a method of ``Algorithm`` objects, which uses
+ the underlying hash algorithm to compute a digest. If there is no appropriate
+ hash algorithm, a ``NotImplementedError`` will be raised
+
`v2.6.0 <https://github.com/jpadilla/pyjwt/compare/2.5.0...2.6.0>`__
-----------------------------------------------------------------------
diff --git a/docs/usage.rst b/docs/usage.rst
index 91d9679..a85fa18 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -297,3 +297,74 @@ Retrieve RSA signing keys from a JWKS endpoint
... )
>>> print(data)
{'iss': 'https://dev-87evx9ru.auth0.com/', 'sub': 'aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC@clients', 'aud': 'https://expenses-api', 'iat': 1572006954, 'exp': 1572006964, 'azp': 'aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC', 'gty': 'client-credentials'}
+
+OIDC Login Flow
+---------------
+
+The following usage demonstrates an OIDC login flow using pyjwt. Further
+reading about the OIDC spec is recommended for implementers.
+
+In particular, this demonstrates validation of the ``at_hash`` claim.
+This claim relies on data from outside of the the JWT for validation. Methods
+are provided which support computation and validation of this claim, but it
+is not built into pyjwt.
+
+.. code-block:: python
+
+ import base64
+ import jwt
+ import requests
+
+
+ # Part 1: setup
+ # get the OIDC config and JWKs to use
+
+ # in OIDC, you must know your client_id (this is the OAuth 2.0 client_id)
+ client_id = ...
+
+ # example of fetching data from your OIDC server
+ # see: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
+ oidc_server = ...
+ oidc_config = requests.get(
+ f"https://{oidc_server}/.well-known/openid-configuration"
+ ).json()
+ signing_algos = oidc_config["id_token_signing_alg_values_supported"]
+
+ # setup a PyJWKClient to get the appropriate signing key
+ jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
+
+
+ # Part 2: login / authorization
+ # when a user completes an OIDC login flow, there will be a well-formed
+ # response object to parse/handle
+
+ # data from the login flow
+ # see: https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
+ token_response = ...
+ id_token = token_response["id_token"]
+ access_token = token_response["access_token"]
+
+
+ # Part 3: decode and validate at_hash
+ # after the login is complete, the id_token needs to be decoded
+ # this is the stage at which an OIDC client must verify the at_hash
+
+ # get signing_key from id_token
+ signing_key = jwks_client.get_signing_key_from_jwt(id_token)
+
+ # now, decode_complete to get payload + header
+ data = jwt.decode_complete(
+ id_token,
+ key=signing_key.key,
+ algorithms=signing_algos,
+ audience=client_id,
+ )
+ payload, header = data["payload"], data["header"]
+
+ # get the pyjwt algorithm object
+ alg_obj = jwt.get_algorithm_by_name(header["alg"])
+
+ # compute at_hash, then validate / assert
+ digest = alg_obj.compute_hash_digest(access_token)
+ at_hash = base64.urlsafe_b64encode(digest[: (len(digest) // 2)]).rstrip("=")
+ assert at_hash == payload["at_hash"]
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 93fadf4..4fae441 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -18,6 +18,7 @@ from .utils import (
try:
import cryptography.exceptions
from cryptography.exceptions import InvalidSignature
+ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric.ec import (
@@ -111,6 +112,28 @@ class Algorithm:
The interface for an algorithm used to sign and verify tokens.
"""
+ def compute_hash_digest(self, bytestr: bytes) -> bytes:
+ """
+ Compute a hash digest using the specified algorithm's hash algorithm.
+
+ If there is no hash algorithm, raises a NotImplementedError.
+ """
+ # lookup self.hash_alg if defined in a way that mypy can understand
+ hash_alg = getattr(self, "hash_alg", None)
+ if hash_alg is None:
+ raise NotImplementedError
+
+ if (
+ has_crypto
+ and isinstance(hash_alg, type)
+ and issubclass(hash_alg, hashes.HashAlgorithm)
+ ):
+ digest = hashes.Hash(hash_alg(), backend=default_backend())
+ digest.update(bytestr)
+ return digest.finalize()
+ else:
+ return hash_alg(bytestr).digest()
+
def prepare_key(self, key):
"""
Performs necessary validation and conversions on the key and returns
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 538078a..894ce28 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -45,6 +45,12 @@ class TestAlgorithms:
with pytest.raises(NotImplementedError):
algo.to_jwk("value")
+ def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self):
+ algo = Algorithm()
+
+ with pytest.raises(NotImplementedError):
+ algo.compute_hash_digest(b"value")
+
def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()
@@ -1054,3 +1060,20 @@ class TestOKPAlgorithms:
signature_2 = algo.sign(b"Hello World!", priv_key_2)
assert algo.verify(b"Hello World!", pub_key_2, signature_1)
assert algo.verify(b"Hello World!", pub_key_2, signature_2)
+
+ @crypto_required
+ def test_rsa_can_compute_digest(self):
+ # this is the well-known sha256 hash of "foo"
+ foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=")
+
+ algo = RSAAlgorithm(RSAAlgorithm.SHA256)
+ computed_hash = algo.compute_hash_digest(b"foo")
+ assert computed_hash == foo_hash
+
+ def test_hmac_can_compute_digest(self):
+ # this is the well-known sha256 hash of "foo"
+ foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=")
+
+ algo = HMACAlgorithm(HMACAlgorithm.SHA256)
+ computed_hash = algo.compute_hash_digest(b"foo")
+ assert computed_hash == foo_hash