import time
from typing import Optional

from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


class JWKSetCache:
    """Single-slot cache holding one JWK set together with its timestamp.

    The cached entry is wrapped in a ``PyJWTSetWithTimestamp``; expiry is
    decided by comparing ``time.monotonic()`` against that wrapper's
    timestamp plus ``lifespan``.
    """

    def __init__(self, lifespan: int) -> None:
        """Create an empty cache.

        Args:
            lifespan: Amount added to the entry's timestamp to obtain its
                expiry point. A value of ``-1`` or lower means the entry is
                never reported as expired.
        """
        self.lifespan = lifespan
        # ``None`` means "nothing cached".
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None

    def put(self, jwk_set: PyJWKSet) -> None:
        """Store ``jwk_set`` in the cache, or clear the cache when it is ``None``."""
        self.jwk_set_with_timestamp = (
            None if jwk_set is None else PyJWTSetWithTimestamp(jwk_set)
        )

    def get(self) -> Optional[PyJWKSet]:
        """Return the cached JWK set, or ``None`` if the cache is empty or expired."""
        if self.jwk_set_with_timestamp is not None and not self.is_expired():
            return self.jwk_set_with_timestamp.get_jwk_set()

        return None

    def is_expired(self) -> bool:
        """Report whether a cached entry exists and has outlived ``lifespan``.

        An empty cache is not considered expired, and neither is any entry
        when ``lifespan`` is not greater than ``-1``.
        """
        entry = self.jwk_set_with_timestamp
        if entry is None:
            return False

        if not self.lifespan > -1:
            # Expiry is disabled for this cache.
            return False

        # NOTE(review): presumably get_timestamp() is a time.monotonic()
        # reading taken when the entry was stored -- confirm in api_jwk.
        return time.monotonic() > entry.get_timestamp() + self.lifespan