summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHaoyu(Jerry) Wu <wuhaoyujerry@users.noreply.github.com>2022-08-01 09:37:48 -0700
committerGitHub <noreply@github.com>2022-08-01 22:37:48 +0600
commitfc5b94eb3575254caba599218246616c75fecdc7 (patch)
treecfd8e75fe2eff1360221e9f22173bae9bed38ff6
parentae3da7469ff8c28b726e082cd671997e09b19d55 (diff)
downloadpyjwt-fc5b94eb3575254caba599218246616c75fecdc7.tar.gz
Add cacheing functionality for JWK set (#781)
* Initial implementation of ttl jwk set cache (cherry picked from commit 479a7c124d63113a2190bd48972cc19172215096) * Add unit test for jwk set cache * Fix failed unit test * Disable cache signing key by default * Add a negative unit test for get_jwk_set * Add functionality to force refresh the jwk set cache when no matching signing key can be found from the cache * Add unit test for refresh cache * Add unit test to unset cache when the network call throws error * fix naming typo * Update unit test naming * Update comment * Add check for lifespan * Update comments for get_signing_key * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix ci error * Add type declaration to fix CI error * Add more unit tests to improve coverage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Try to increase test coverage to 100% Co-authored-by: Jerry Wu <hawu@roku.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-rw-r--r--jwt/api_jwk.py13
-rw-r--r--jwt/jwk_set_cache.py32
-rw-r--r--jwt/jwks_client.py82
-rw-r--r--tests/test_jwks_client.py169
4 files changed, 263 insertions, 33 deletions
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index bbd8a9e..aa3dd32 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
+import time
from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
@@ -110,3 +111,15 @@ class PyJWKSet:
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")
+
+
+class PyJWTSetWithTimestamp:
+ def __init__(self, jwk_set: PyJWKSet):
+ self.jwk_set = jwk_set
+ self.timestamp = time.monotonic()
+
+ def get_jwk_set(self):
+ return self.jwk_set
+
+ def get_timestamp(self):
+ return self.timestamp
diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py
new file mode 100644
index 0000000..e8c2a7e
--- /dev/null
+++ b/jwt/jwk_set_cache.py
@@ -0,0 +1,32 @@
+import time
+from typing import Optional
+
+from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
+
+
+class JWKSetCache:
+ def __init__(self, lifespan: int):
+ self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
+ self.lifespan = lifespan
+
+ def put(self, jwk_set: PyJWKSet):
+ if jwk_set is not None:
+ self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
+ else:
+ # clear cache
+ self.jwk_set_with_timestamp = None
+
+ def get(self) -> Optional[PyJWKSet]:
+ if self.jwk_set_with_timestamp is None or self.is_expired():
+ return None
+
+ return self.jwk_set_with_timestamp.get_jwk_set()
+
+ def is_expired(self) -> bool:
+
+ return (
+ self.jwk_set_with_timestamp is not None
+ and self.lifespan > -1
+ and time.monotonic()
+ > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
+ )
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index 767b717..b4e9800 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -1,31 +1,68 @@
import json
import urllib.request
from functools import lru_cache
-from typing import Any, List
+from typing import Any, List, Optional
+from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
+from .jwk_set_cache import JWKSetCache
class PyJWKClient:
- def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
+ def __init__(
+ self,
+ uri: str,
+ cache_keys: bool = False,
+ max_cached_keys: int = 16,
+ cache_jwk_set: bool = True,
+ lifespan: int = 300,
+ ):
self.uri = uri
+ self.jwk_set_cache: Optional[JWKSetCache] = None
+
+ if cache_jwk_set:
+ # Init jwt set cache with default or given lifespan.
+ # Default lifespan is 300 seconds (5 minutes).
+ if lifespan <= 0:
+ raise PyJWKClientError(
+ f'Lifespan must be greater than 0, the input is "{lifespan}"'
+ )
+ self.jwk_set_cache = JWKSetCache(lifespan)
+ else:
+ self.jwk_set_cache = None
+
if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
def fetch_data(self) -> Any:
- with urllib.request.urlopen(self.uri) as response:
- return json.load(response)
+ jwk_set: Any = None
+ try:
+ with urllib.request.urlopen(self.uri) as response:
+ jwk_set = json.load(response)
+ except URLError as e:
+ raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
+ else:
+ return jwk_set
+ finally:
+ if self.jwk_set_cache is not None:
+ self.jwk_set_cache.put(jwk_set)
+
+ def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
+ data = None
+ if self.jwk_set_cache is not None and not refresh:
+ data = self.jwk_set_cache.get()
+
+ if data is None:
+ data = self.fetch_data()
- def get_jwk_set(self) -> PyJWKSet:
- data = self.fetch_data()
return PyJWKSet.from_dict(data)
- def get_signing_keys(self) -> List[PyJWK]:
- jwk_set = self.get_jwk_set()
+ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
+ jwk_set = self.get_jwk_set(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
@@ -39,17 +76,17 @@ class PyJWKClient:
def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
- signing_key = None
-
- for key in signing_keys:
- if key.key_id == kid:
- signing_key = key
- break
+ signing_key = self.match_kid(signing_keys, kid)
if not signing_key:
- raise PyJWKClientError(
- f'Unable to find a signing key that matches: "{kid}"'
- )
+ # If no matching signing key from the jwk set, refresh the jwk set and try again.
+ signing_keys = self.get_signing_keys(refresh=True)
+ signing_key = self.match_kid(signing_keys, kid)
+
+ if not signing_key:
+ raise PyJWKClientError(
+ f'Unable to find a signing key that matches: "{kid}"'
+ )
return signing_key
@@ -57,3 +94,14 @@ class PyJWKClient:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))
+
+ @staticmethod
+ def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
+ signing_key = None
+
+ for key in signing_keys:
+ if key.key_id == kid:
+ signing_key = key
+ break
+
+ return signing_key
diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py
index 3e42da1..c95dfcc 100644
--- a/tests/test_jwks_client.py
+++ b/tests/test_jwks_client.py
@@ -1,6 +1,8 @@
import contextlib
import json
+import time
from unittest import mock
+from urllib.error import URLError
import pytest
@@ -11,7 +13,7 @@ from jwt.exceptions import PyJWKClientError
from .utils import crypto_required
-RESPONSE_DATA = {
+RESPONSE_DATA_WITH_MATCHING_KID = {
"keys": [
{
"alg": "RS256",
@@ -28,9 +30,22 @@ RESPONSE_DATA = {
]
}
+RESPONSE_DATA_NO_MATCHING_KID = {
+ "keys": [
+ {
+ "alg": "RS256",
+ "kty": "RSA",
+ "use": "sig",
+ "n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ",
+ "e": "AQAB",
+ "kid": "MLYHNMMhwCNXw9roHIILFsK4nLs=",
+ }
+ ]
+}
+
@contextlib.contextmanager
-def mocked_response(data):
+def mocked_success_response(data):
with mock.patch("urllib.request.urlopen") as urlopen_mock:
response = mock.Mock()
response.__enter__ = mock.Mock(return_value=response)
@@ -40,12 +55,35 @@ def mocked_response(data):
yield urlopen_mock
+@contextlib.contextmanager
+def mocked_failed_response():
+ with mock.patch("urllib.request.urlopen") as urlopen_mock:
+ urlopen_mock.side_effect = URLError("Fail to process the request.")
+ yield urlopen_mock
+
+
+@contextlib.contextmanager
+def mocked_first_call_wrong_kid_second_call_correct_kid(
+ response_data_one, response_data_two
+):
+ with mock.patch("urllib.request.urlopen") as urlopen_mock:
+ response = mock.Mock()
+ response.__enter__ = mock.Mock(return_value=response)
+ response.__exit__ = mock.Mock()
+ response.read.side_effect = [
+ json.dumps(response_data_one),
+ json.dumps(response_data_two),
+ ]
+ urlopen_mock.return_value = response
+ yield urlopen_mock
+
+
@crypto_required
class TestPyJWKClient:
def test_get_jwk_set(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
jwk_set = jwks_client.get_jwk_set()
@@ -54,7 +92,7 @@ class TestPyJWKClient:
def test_get_signing_keys(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_keys = jwks_client.get_signing_keys()
@@ -64,11 +102,11 @@ class TestPyJWKClient:
def test_get_signing_keys_if_no_use_provided(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
- mocked_key = RESPONSE_DATA["keys"][0].copy()
+ mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy()
del mocked_key["use"]
response = {"keys": [mocked_key]}
- with mocked_response(response):
+ with mocked_success_response(response):
jwks_client = PyJWKClient(url)
signing_keys = jwks_client.get_signing_keys()
@@ -78,10 +116,10 @@ class TestPyJWKClient:
def test_get_signing_keys_raises_if_none_found(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
- mocked_key = RESPONSE_DATA["keys"][0].copy()
+ mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy()
mocked_key["use"] = "enc"
response = {"keys": [mocked_key]}
- with mocked_response(response):
+ with mocked_success_response(response):
jwks_client = PyJWKClient(url)
with pytest.raises(PyJWKClientError) as exc:
@@ -93,7 +131,7 @@ class TestPyJWKClient:
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_key = jwks_client.get_signing_key(kid)
@@ -106,14 +144,14 @@ class TestPyJWKClient:
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
- jwks_client = PyJWKClient(url)
+ jwks_client = PyJWKClient(url, cache_keys=True)
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_signing_key(kid)
# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
- with mocked_response(RESPONSE_DATA) as repeated_call:
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_signing_key(kid)
assert repeated_call.call_count == 0
@@ -122,14 +160,14 @@ class TestPyJWKClient:
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
- jwks_client = PyJWKClient(url, cache_keys=False)
+ jwks_client = PyJWKClient(url, cache_jwk_set=False)
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_signing_key(kid)
# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
- with mocked_response(RESPONSE_DATA) as repeated_call:
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_signing_key(kid)
assert repeated_call.call_count == 1
@@ -138,7 +176,7 @@ class TestPyJWKClient:
token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA"
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
- with mocked_response(RESPONSE_DATA):
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_key = jwks_client.get_signing_key_from_jwt(token)
@@ -159,3 +197,102 @@ class TestPyJWKClient:
"azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC",
"gty": "client-credentials",
}
+
+ def test_get_jwk_set_caches_result(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+
+ jwks_client = PyJWKClient(url)
+ assert jwks_client.jwk_set_cache is not None
+
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
+ jwks_client.get_jwk_set()
+
+ # mocked_response does not allow urllib.request.urlopen to be called twice
+ # so a second mock is needed
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
+ jwks_client.get_jwk_set()
+
+ assert repeated_call.call_count == 0
+
+ def test_get_jwt_set_cache_expired_result(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+
+ jwks_client = PyJWKClient(url, lifespan=1)
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
+ jwks_client.get_jwk_set()
+
+ time.sleep(2)
+
+ # mocked_response does not allow urllib.request.urlopen to be called twice
+ # so a second mock is needed
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
+ jwks_client.get_jwk_set()
+
+ assert repeated_call.call_count == 1
+
+ def test_get_jwt_set_cache_disabled(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+
+ jwks_client = PyJWKClient(url, cache_jwk_set=False)
+ assert jwks_client.jwk_set_cache is None
+
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
+ jwks_client.get_jwk_set()
+
+ assert jwks_client.jwk_set_cache is None
+
+ time.sleep(2)
+
+ # mocked_response does not allow urllib.request.urlopen to be called twice
+ # so a second mock is needed
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
+ jwks_client.get_jwk_set()
+
+ assert repeated_call.call_count == 1
+
+ def test_get_jwt_set_failed_request_should_clear_cache(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+
+ jwks_client = PyJWKClient(url)
+ with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
+ jwks_client.get_jwk_set()
+
+ with pytest.raises(PyJWKClientError):
+ with mocked_failed_response():
+ jwks_client.get_jwk_set(refresh=True)
+
+ assert jwks_client.jwk_set_cache is None
+
+ def test_get_jwt_set_refresh_cache(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+ jwks_client = PyJWKClient(url)
+
+ kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
+
+ # The first call will return response with no matching kid,
+ # the function should make another call to try to refresh the cache.
+ with mocked_first_call_wrong_kid_second_call_correct_kid(
+ RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID
+ ) as call_data:
+ jwks_client.get_signing_key(kid)
+
+ assert call_data.call_count == 2
+
+ def test_get_jwt_set_no_matching_kid_after_second_attempt(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+ jwks_client = PyJWKClient(url)
+
+ kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"
+
+ with pytest.raises(PyJWKClientError):
+ with mocked_first_call_wrong_kid_second_call_correct_kid(
+ RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID
+ ):
+ jwks_client.get_signing_key(kid)
+
+ def test_get_jwt_set_invalid_lifespan(self):
+ url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
+
+ with pytest.raises(PyJWKClientError):
+ jwks_client = PyJWKClient(url, lifespan=-1)
+ assert jwks_client is None