summaryrefslogtreecommitdiff
path: root/jwt/api_jwk.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jwk.py')
-rw-r--r--jwt/api_jwk.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index aa3dd32..ef8cce8 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -2,13 +2,15 @@ from __future__ import annotations
import json
import time
+from typing import Any, Optional
from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
+from .types import JWKDict
class PyJWK:
- def __init__(self, jwk_data, algorithm=None):
+ def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
@@ -55,29 +57,29 @@ class PyJWK:
self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod
- def from_dict(obj, algorithm=None):
+ def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
return PyJWK(obj, algorithm)
@staticmethod
- def from_json(data, algorithm=None):
+ def from_json(data: str, algorithm: None = None) -> "PyJWK":
obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm)
@property
- def key_type(self):
+ def key_type(self) -> str:
return self._jwk_data.get("kty", None)
@property
- def key_id(self):
+ def key_id(self) -> str:
return self._jwk_data.get("kid", None)
@property
- def public_key_use(self):
+ def public_key_use(self) -> Optional[str]:
return self._jwk_data.get("use", None)
class PyJWKSet:
- def __init__(self, keys: list[dict]) -> None:
+ def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []
if not keys:
@@ -97,16 +99,16 @@ class PyJWKSet:
raise PyJWKSetError("The JWK Set did not contain any usable keys")
@staticmethod
- def from_dict(obj):
+ def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
keys = obj.get("keys", [])
return PyJWKSet(keys)
@staticmethod
- def from_json(data):
+ def from_json(data: str) -> "PyJWKSet":
obj = json.loads(data)
return PyJWKSet.from_dict(obj)
- def __getitem__(self, kid):
+ def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys:
if key.key_id == kid:
return key
@@ -118,8 +120,8 @@ class PyJWTSetWithTimestamp:
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
- def get_jwk_set(self):
+ def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set
- def get_timestamp(self):
+ def get_timestamp(self) -> float:
return self.timestamp