diff options
Diffstat (limited to 'jwt/api_jwk.py')
-rw-r--r-- | jwt/api_jwk.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index aa3dd32..ef8cce8 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -2,13 +2,15 @@ from __future__ import annotations import json import time +from typing import Any, Optional from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError +from .types import JWKDict class PyJWK: - def __init__(self, jwk_data, algorithm=None): + def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None: self._algorithms = get_default_algorithms() self._jwk_data = jwk_data @@ -55,29 +57,29 @@ class PyJWK: self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj, algorithm=None): + def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod - def from_json(data, algorithm=None): + def from_json(data: str, algorithm: None = None) -> "PyJWK": obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @property - def key_type(self): + def key_type(self) -> str: return self._jwk_data.get("kty", None) @property - def key_id(self): + def key_id(self) -> str: return self._jwk_data.get("kid", None) @property - def public_key_use(self): + def public_key_use(self) -> Optional[str]: return self._jwk_data.get("use", None) class PyJWKSet: - def __init__(self, keys: list[dict]) -> None: + def __init__(self, keys: list[JWKDict]) -> None: self.keys = [] if not keys: @@ -97,16 +99,16 @@ class PyJWKSet: raise PyJWKSetError("The JWK Set did not contain any usable keys") @staticmethod - def from_dict(obj): + def from_dict(obj: dict[str, Any]) -> "PyJWKSet": keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data): + def from_json(data: str) -> "PyJWKSet": obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid): + def __getitem__(self, kid: str) -> "PyJWK": for key in self.keys: if key.key_id == kid: return key @@ -118,8 +120,8 @@ class PyJWTSetWithTimestamp: self.jwk_set = jwk_set self.timestamp = time.monotonic() - def get_jwk_set(self): + def get_jwk_set(self) -> PyJWKSet: return self.jwk_set - def get_timestamp(self): + def get_timestamp(self) -> float: return self.timestamp |