diff options
author | Haoyu(Jerry) Wu <wuhaoyujerry@users.noreply.github.com> | 2022-08-01 09:37:48 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-01 22:37:48 +0600 |
commit | fc5b94eb3575254caba599218246616c75fecdc7 (patch) | |
tree | cfd8e75fe2eff1360221e9f22173bae9bed38ff6 /jwt | |
parent | ae3da7469ff8c28b726e082cd671997e09b19d55 (diff) | |
download | pyjwt-fc5b94eb3575254caba599218246616c75fecdc7.tar.gz |
Add cacheing functionality for JWK set (#781)
* Initial implementation of ttl jwk set cache
(cherry picked from commit 479a7c124d63113a2190bd48972cc19172215096)
* Add unit test for jwk set cache
* Fix failed unit test
* Disable cache signing key by default
* Add a negative unit test for get_jwk_set
* Add functionality to force refresh the jwk set cache when no matching signing key can be found from the cache
* Add unit test for refresh cache
* Add unit test to unset cache when the network call throws error
* fix naming typo
* Update unit test naming
* Update comment
* Add check for lifespan
* Update comments for get_signing_key
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix ci error
* Add type declaration to fix CI error
* Add more unit tests to improve coverage
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Try to increase test coverage to 100%
Co-authored-by: Jerry Wu <hawu@roku.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Diffstat (limited to 'jwt')
-rw-r--r-- | jwt/api_jwk.py | 13 | ||||
-rw-r--r-- | jwt/jwk_set_cache.py | 32 | ||||
-rw-r--r-- | jwt/jwks_client.py | 82 |
3 files changed, 110 insertions, 17 deletions
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index bbd8a9e..aa3dd32 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import time from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -110,3 +111,15 @@ class PyJWKSet: if key.key_id == kid: return key raise KeyError(f"keyset has no key for kid: {kid}") + + +class PyJWTSetWithTimestamp: + def __init__(self, jwk_set: PyJWKSet): + self.jwk_set = jwk_set + self.timestamp = time.monotonic() + + def get_jwk_set(self): + return self.jwk_set + + def get_timestamp(self): + return self.timestamp diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py new file mode 100644 index 0000000..e8c2a7e --- /dev/null +++ b/jwt/jwk_set_cache.py @@ -0,0 +1,32 @@ +import time +from typing import Optional + +from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp + + +class JWKSetCache: + def __init__(self, lifespan: int): + self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None + self.lifespan = lifespan + + def put(self, jwk_set: PyJWKSet): + if jwk_set is not None: + self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) + else: + # clear cache + self.jwk_set_with_timestamp = None + + def get(self) -> Optional[PyJWKSet]: + if self.jwk_set_with_timestamp is None or self.is_expired(): + return None + + return self.jwk_set_with_timestamp.get_jwk_set() + + def is_expired(self) -> bool: + + return ( + self.jwk_set_with_timestamp is not None + and self.lifespan > -1 + and time.monotonic() + > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + ) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 767b717..b4e9800 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,31 +1,68 @@ import json import urllib.request from functools import lru_cache -from typing import Any, List +from typing import Any, List, Optional +from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientError +from .jwk_set_cache import JWKSetCache class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): + def __init__( + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, + ): self.uri = uri + self.jwk_set_cache: Optional[JWKSetCache] = None + + if cache_jwk_set: + # Init jwt set cache with default or given lifespan. + # Default lifespan is 300 seconds (5 minutes). + if lifespan <= 0: + raise PyJWKClientError( + f'Lifespan must be greater than 0, the input is "{lifespan}"' + ) + self.jwk_set_cache = JWKSetCache(lifespan) + else: + self.jwk_set_cache = None + if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: - with urllib.request.urlopen(self.uri) as response: - return json.load(response) + jwk_set: Any = None + try: + with urllib.request.urlopen(self.uri) as response: + jwk_set = json.load(response) + except URLError as e: + raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + else: + return jwk_set + finally: + if self.jwk_set_cache is not None: + self.jwk_set_cache.put(jwk_set) + + def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: + data = None + if self.jwk_set_cache is not None and not refresh: + data = self.jwk_set_cache.get() + + if data is None: + data = self.fetch_data() - def get_jwk_set(self) -> PyJWKSet: - data = self.fetch_data() return PyJWKSet.from_dict(data) - def get_signing_keys(self) -> List[PyJWK]: - jwk_set = self.get_jwk_set() + def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: + jwk_set = self.get_jwk_set(refresh) signing_keys = [ jwk_set_key for jwk_set_key in jwk_set.keys @@ -39,17 +76,17 @@ class PyJWKClient: def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys() - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break + signing_key = self.match_kid(signing_keys, kid) if not signing_key: - raise PyJWKClientError( - f'Unable to find a signing key that matches: "{kid}"' - ) + # If no matching signing key from the jwk set, refresh the jwk set and try again. + signing_keys = self.get_signing_keys(refresh=True) + signing_key = self.match_kid(signing_keys, kid) + + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) return signing_key @@ -57,3 +94,14 @@ class PyJWKClient: unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) + + @staticmethod + def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + return signing_key |