diff options
-rw-r--r-- | jwt/jwks_client.py | 6 | ||||
-rw-r--r-- | tests/test_jwks_client.py | 17 |
2 files changed, 21 insertions, 2 deletions
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index aa33bb3..e237186 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -19,12 +19,14 @@ class PyJWKClient: cache_jwk_set: bool = True, lifespan: int = 300, headers: Optional[Dict[str, Any]] = None, + timeout: int = 30, ): if headers is None: headers = {} self.uri = uri self.jwk_set_cache: Optional[JWKSetCache] = None self.headers = headers + self.timeout = timeout if cache_jwk_set: # Init jwt set cache with default or given lifespan. @@ -46,9 +48,9 @@ class PyJWKClient: jwk_set: Any = None try: r = urllib.request.Request(url=self.uri, headers=self.headers) - with urllib.request.urlopen(r) as response: + with urllib.request.urlopen(r, timeout=self.timeout) as response: jwk_set = json.load(response) - except URLError as e: + except (URLError, TimeoutError) as e: raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') else: return jwk_set diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 5029fe1..5886c6a 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -78,6 +78,13 @@ def mocked_first_call_wrong_kid_second_call_correct_kid( yield urlopen_mock +@contextlib.contextmanager +def mocked_timeout(): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + urlopen_mock.side_effect = TimeoutError("timed out") + yield urlopen_mock + + @crypto_required class TestPyJWKClient: def test_fetch_data_forwards_headers_to_correct_url(self): @@ -309,3 +316,13 @@ class TestPyJWKClient: with pytest.raises(PyJWKClientError): jwks_client = PyJWKClient(url, lifespan=-1) assert jwks_client is None + + def test_get_jwt_set_timeout(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url, timeout=5) + + with pytest.raises(PyJWKClientError) as exc: + with mocked_timeout(): + jwks_client.get_jwk_set() + + assert 'Fail to fetch data from the url, err: "timed out"' in str(exc.value) |