summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2022-12-10 17:44:04 +0200
committerGitHub <noreply@github.com>2022-12-10 21:44:04 +0600
commit2306f68b87c5f7eecbd838db900a2d3685b81b3c (patch)
treeeba477f54ab693600e65cf325842fc0c15436a1b
parentfb9b3119449353dca6c03be27f32e482852473c3 (diff)
downloadpyjwt-2306f68b87c5f7eecbd838db900a2d3685b81b3c.tar.gz
Make mypy configuration stricter and improve typing (#830)
* PyJWS._verify_signature: raise early KeyError if header is missing alg * Make Mypy configuration stricter * Improve typing in jwt.utils * Improve typing in jwt.help * Improve typing in jwt.exceptions * Improve typing in jwt.api_jwk * Improve typing in jwt.api_jws * Improve typing & clean up imports in jwt.algorithms * Correct JWS.decode rettype to any (payload could be something else) * Update typing in api_jwt * Improve typing in jwks_client * Improve typing in docs/conf.py * Fix (benign) mistyping in test_advisory * Fix misc type complaints in tests
-rw-r--r--docs/conf.py4
-rw-r--r--jwt/algorithms.py99
-rw-r--r--jwt/api_jwk.py26
-rw-r--r--jwt/api_jws.py34
-rw-r--r--jwt/api_jwt.py58
-rw-r--r--jwt/exceptions.py4
-rw-r--r--jwt/help.py8
-rw-r--r--jwt/jwks_client.py9
-rw-r--r--jwt/types.py3
-rw-r--r--jwt/utils.py30
-rw-r--r--pyproject.toml8
-rw-r--r--tests/test_advisory.py12
-rw-r--r--tests/test_algorithms.py6
-rw-r--r--tests/test_api_jwk.py2
14 files changed, 187 insertions, 116 deletions
diff --git a/docs/conf.py b/docs/conf.py
index 44b5103..938631f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -4,7 +4,7 @@ import re
import sphinx_rtd_theme
-def read(*parts):
+def read(*parts) -> str:
"""
Build an absolute path from *parts* and and return the contents of the
resulting file. Assume UTF-8 encoding.
@@ -14,7 +14,7 @@ def read(*parts):
return f.read()
-def find_version(*file_paths):
+def find_version(*file_paths) -> str:
"""
Build a path from *file_paths* and search for a ``__version__``
string inside.
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 4fae441..c809f15 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,8 +1,10 @@
import hashlib
import hmac
import json
+from typing import Any, Dict, Union
from .exceptions import InvalidKeyError
+from .types import JWKDict
from .utils import (
base64url_decode,
base64url_encode,
@@ -20,10 +22,18 @@ try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.asymmetric import ec, padding
+ from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import (
+ ECDSA,
+ SECP256K1,
+ SECP256R1,
+ SECP384R1,
+ SECP521R1,
+ EllipticCurve,
EllipticCurvePrivateKey,
+ EllipticCurvePrivateNumbers,
EllipticCurvePublicKey,
+ EllipticCurvePublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey,
@@ -73,7 +83,7 @@ requires_cryptography = {
}
-def get_default_algorithms():
+def get_default_algorithms() -> Dict[str, "Algorithm"]:
"""
Returns the algorithms that are implemented by the library.
"""
@@ -130,25 +140,29 @@ class Algorithm:
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
- return digest.finalize()
+ return bytes(digest.finalize())
else:
- return hash_alg(bytestr).digest()
+ return bytes(hash_alg(bytestr).digest())
- def prepare_key(self, key):
+ # TODO: all key-related `Any`s in this class should optimally be made
+ # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
+ # that may still be poorly supported.
+
+ def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
@@ -156,14 +170,14 @@ class Algorithm:
raise NotImplementedError
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
@@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm):
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -242,7 +256,7 @@ class HMACAlgorithm(Algorithm):
return base64url_decode(obj["k"])
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg, key, sig):
@@ -261,7 +275,7 @@ if has_crypto:
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -271,16 +285,15 @@ if has_crypto:
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
- key = force_bytes(key)
+ key_bytes = force_bytes(key)
try:
- if key.startswith(b"ssh-rsa"):
- key = load_ssh_public_key(key)
+ if key_bytes.startswith(b"ssh-rsa"):
+ return load_ssh_public_key(key_bytes)
else:
- key = load_pem_private_key(key, password=None)
+ return load_pem_private_key(key_bytes, password=None)
except ValueError:
- key = load_pem_public_key(key)
- return key
+ return load_pem_public_key(key_bytes)
@staticmethod
def to_jwk(key_obj):
@@ -383,12 +396,10 @@ if has_crypto:
return numbers.private_key()
elif "n" in obj and "e" in obj:
# Public key
- numbers = RSAPublicNumbers(
+ return RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
- )
-
- return numbers.public_key()
+ ).public_key()
else:
raise InvalidKeyError("Not a public or private key")
@@ -412,7 +423,7 @@ if has_crypto:
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -422,18 +433,18 @@ if has_crypto:
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
- key = force_bytes(key)
+ key_bytes = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
- if key.startswith(b"ecdsa-sha2-"):
- key = load_ssh_public_key(key)
+ if key_bytes.startswith(b"ecdsa-sha2-"):
+ key = load_ssh_public_key(key_bytes)
else:
- key = load_pem_public_key(key)
+ key = load_pem_public_key(key_bytes)
except ValueError:
- key = load_pem_private_key(key, password=None)
+ key = load_pem_private_key(key_bytes, password=None)
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
@@ -444,7 +455,7 @@ if has_crypto:
return key
def sign(self, msg, key):
- der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
+ der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
@@ -457,7 +468,7 @@ if has_crypto:
try:
if isinstance(key, EllipticCurvePrivateKey):
key = key.public_key()
- key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
+ key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@@ -472,13 +483,13 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- if isinstance(key_obj.curve, ec.SECP256R1):
+ if isinstance(key_obj.curve, SECP256R1):
crv = "P-256"
- elif isinstance(key_obj.curve, ec.SECP384R1):
+ elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384"
- elif isinstance(key_obj.curve, ec.SECP521R1):
+ elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521"
- elif isinstance(key_obj.curve, ec.SECP256K1):
+ elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
@@ -498,7 +509,9 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(
+ jwk: Any,
+ ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -519,24 +532,26 @@ if has_crypto:
y = base64url_decode(obj.get("y"))
curve = obj.get("crv")
+ curve_obj: EllipticCurve
+
if curve == "P-256":
if len(x) == len(y) == 32:
- curve_obj = ec.SECP256R1()
+ curve_obj = SECP256R1()
else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
- curve_obj = ec.SECP384R1()
+ curve_obj = SECP384R1()
else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
- curve_obj = ec.SECP521R1()
+ curve_obj = SECP521R1()
else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1":
if len(x) == len(y) == 32:
- curve_obj = ec.SECP256K1()
+ curve_obj = SECP256K1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1"
@@ -544,7 +559,7 @@ if has_crypto:
else:
raise InvalidKeyError(f"Invalid curve: {curve}")
- public_numbers = ec.EllipticCurvePublicNumbers(
+ public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve_obj,
@@ -559,7 +574,7 @@ if has_crypto:
"D should be {} bytes for curve {}", len(x), curve
)
- return ec.EllipticCurvePrivateNumbers(
+ return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers
).private_key()
@@ -600,7 +615,7 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs) -> None:
pass
def prepare_key(self, key):
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index aa3dd32..ef8cce8 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -2,13 +2,15 @@ from __future__ import annotations
import json
import time
+from typing import Any, Optional
from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
+from .types import JWKDict
class PyJWK:
- def __init__(self, jwk_data, algorithm=None):
+ def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
@@ -55,29 +57,29 @@ class PyJWK:
self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod
- def from_dict(obj, algorithm=None):
+ def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
return PyJWK(obj, algorithm)
@staticmethod
- def from_json(data, algorithm=None):
+ def from_json(data: str, algorithm: None = None) -> "PyJWK":
obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm)
@property
- def key_type(self):
+ def key_type(self) -> str:
return self._jwk_data.get("kty", None)
@property
- def key_id(self):
+ def key_id(self) -> str:
return self._jwk_data.get("kid", None)
@property
- def public_key_use(self):
+ def public_key_use(self) -> Optional[str]:
return self._jwk_data.get("use", None)
class PyJWKSet:
- def __init__(self, keys: list[dict]) -> None:
+ def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []
if not keys:
@@ -97,16 +99,16 @@ class PyJWKSet:
raise PyJWKSetError("The JWK Set did not contain any usable keys")
@staticmethod
- def from_dict(obj):
+ def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
keys = obj.get("keys", [])
return PyJWKSet(keys)
@staticmethod
- def from_json(data):
+ def from_json(data: str) -> "PyJWKSet":
obj = json.loads(data)
return PyJWKSet.from_dict(obj)
- def __getitem__(self, kid):
+ def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys:
if key.key_id == kid:
return key
@@ -118,8 +120,8 @@ class PyJWTSetWithTimestamp:
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
- def get_jwk_set(self):
+ def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set
- def get_timestamp(self):
+ def get_timestamp(self) -> float:
return self.timestamp
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index c914db1..0a775d7 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import binascii
import json
import warnings
-from typing import Any, Type
+from typing import Any, List, Optional, Type
from .algorithms import (
Algorithm,
@@ -24,7 +24,11 @@ from .warnings import RemovedInPyjwt3Warning
class PyJWS:
header_typ = "JWT"
- def __init__(self, algorithms=None, options=None) -> None:
+ def __init__(
+ self,
+ algorithms: Optional[List[str]] = None,
+ options: Optional[dict[str, Any]] = None,
+ ) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
@@ -164,8 +168,8 @@ class PyJWS:
def decode_complete(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -209,13 +213,13 @@ class PyJWS:
def decode(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
- ) -> str:
+ ) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@@ -228,7 +232,7 @@ class PyJWS:
)
return decoded["payload"]
- def get_unverified_header(self, jwt: str | bytes) -> dict:
+ def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
@@ -239,7 +243,7 @@ class PyJWS:
return headers
- def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
+ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
@@ -280,13 +284,15 @@ class PyJWS:
def _verify_signature(
self,
signing_input: bytes,
- header: dict,
+ header: dict[str, Any],
signature: bytes,
- key: str = "",
+ key: str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
-
- alg = header.get("alg")
+ try:
+ alg = header["alg"]
+ except KeyError:
+ raise InvalidAlgorithmError("Algorithm not specified")
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")
@@ -304,7 +310,7 @@ class PyJWS:
if "kid" in headers:
self._validate_kid(headers["kid"])
- def _validate_kid(self, kid: str) -> None:
+ def _validate_kid(self, kid: Any) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 5d21afc..8840708 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -21,10 +21,10 @@ from .warnings import RemovedInPyjwt3Warning
class PyJWT:
- def __init__(self, options=None):
+ def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
if options is None:
options = {}
- self.options = {**self._get_default_options(), **options}
+ self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -96,8 +96,8 @@ class PyJWT:
def decode_complete(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
@@ -181,8 +181,8 @@ class PyJWT:
def decode(
self,
- jwt: str,
- key: str = "",
+ jwt: str | bytes,
+ key: str | bytes = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
@@ -196,7 +196,7 @@ class PyJWT:
leeway: Union[int, float, timedelta] = 0,
# kwargs
**kwargs,
- ) -> Dict[str, Any]:
+ ) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@@ -217,7 +217,14 @@ class PyJWT:
)
return decoded["payload"]
- def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
+ def _validate_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ audience=None,
+ issuer=None,
+ leeway: float | timedelta = 0,
+ ) -> None:
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
@@ -243,12 +250,21 @@ class PyJWT:
if options["verify_aud"]:
self._validate_aud(payload, audience)
- def _validate_required_claims(self, payload, options):
+ def _validate_required_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ ) -> None:
for claim in options["require"]:
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)
- def _validate_iat(self, payload, now, leeway):
+ def _validate_iat(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
iat = payload["iat"]
try:
int(iat)
@@ -257,7 +273,12 @@ class PyJWT:
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
- def _validate_nbf(self, payload, now, leeway):
+ def _validate_nbf(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
try:
nbf = int(payload["nbf"])
except ValueError:
@@ -266,7 +287,12 @@ class PyJWT:
if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
- def _validate_exp(self, payload, now, leeway):
+ def _validate_exp(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
try:
exp = int(payload["exp"])
except ValueError:
@@ -275,7 +301,11 @@ class PyJWT:
if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
- def _validate_aud(self, payload, audience):
+ def _validate_aud(
+ self,
+ payload: dict[str, Any],
+ audience: Optional[Union[str, Iterable[str]]],
+ ) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
return
@@ -303,7 +333,7 @@ class PyJWT:
if all(aud not in audience_claims for aud in audience):
raise InvalidAudienceError("Invalid audience")
- def _validate_iss(self, payload, issuer):
+ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if issuer is None:
return
diff --git a/jwt/exceptions.py b/jwt/exceptions.py
index ee201ad..dbe0056 100644
--- a/jwt/exceptions.py
+++ b/jwt/exceptions.py
@@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):
class MissingRequiredClaimError(InvalidTokenError):
- def __init__(self, claim):
+ def __init__(self, claim: str) -> None:
self.claim = claim
- def __str__(self):
+ def __str__(self) -> str:
return f'Token is missing the "{self.claim}" claim'
diff --git a/jwt/help.py b/jwt/help.py
index 0c02eb9..80b0ca5 100644
--- a/jwt/help.py
+++ b/jwt/help.py
@@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version
try:
import cryptography
+
+ cryptography_version = cryptography.__version__
except ModuleNotFoundError:
- cryptography = None
+ cryptography_version = ""
def info() -> Dict[str, Dict[str, str]]:
@@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]:
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == "PyPy":
- pypy_version_info = getattr(sys, "pypy_version_info")
+ pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
implementation_version = (
f"{pypy_version_info.major}."
f"{pypy_version_info.minor}."
@@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]:
"name": implementation,
"version": implementation_version,
},
- "cryptography": {"version": getattr(cryptography, "__version__", "")},
+ "cryptography": {"version": cryptography_version},
"pyjwt": {"version": pyjwt_version},
}
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index daeb830..aa33bb3 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -1,7 +1,7 @@
import json
import urllib.request
from functools import lru_cache
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet
@@ -18,8 +18,10 @@ class PyJWKClient:
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
- headers: dict = {},
+ headers: Optional[Dict[str, Any]] = None,
):
+ if headers is None:
+ headers = {}
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None
self.headers = headers
@@ -62,6 +64,9 @@ class PyJWKClient:
if data is None:
data = self.fetch_data()
+ if not isinstance(data, dict):
+ raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
+
return PyJWKSet.from_dict(data)
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
diff --git a/jwt/types.py b/jwt/types.py
new file mode 100644
index 0000000..830b185
--- /dev/null
+++ b/jwt/types.py
@@ -0,0 +1,3 @@
+from typing import Any, Dict
+
+JWKDict = Dict[str, Any]
diff --git a/jwt/utils.py b/jwt/utils.py
index 16cae06..92e3c8b 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -1,7 +1,7 @@
import base64
import binascii
import re
-from typing import Union
+from typing import Any, AnyStr
try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
@@ -10,10 +10,10 @@ try:
encode_dss_signature,
)
except ModuleNotFoundError:
- EllipticCurve = None
+ pass
-def force_bytes(value: Union[str, bytes]) -> bytes:
+def force_bytes(value: Any) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
@@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes:
raise TypeError("Expected a string value")
-def base64url_decode(input: Union[str, bytes]) -> bytes:
- if isinstance(input, str):
- input = input.encode("ascii")
+def base64url_decode(input: AnyStr) -> bytes:
+ input_bytes = force_bytes(input)
- rem = len(input) % 4
+ rem = len(input_bytes) % 4
if rem > 0:
- input += b"=" * (4 - rem)
+ input_bytes += b"=" * (4 - rem)
- return base64.urlsafe_b64decode(input)
+ return base64.urlsafe_b64decode(input_bytes)
def base64url_encode(input: bytes) -> bytes:
@@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes:
return base64url_encode(int_bytes)
-def from_base64url_uint(val: Union[str, bytes]) -> int:
- if isinstance(val, str):
- val = val.encode("ascii")
-
- data = base64url_decode(val)
+def from_base64url_uint(val: AnyStr) -> int:
+ data = base64url_decode(force_bytes(val))
return int.from_bytes(data, byteorder="big")
@@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes:
return val.to_bytes(byte_length, "big", signed=False)
-def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
+def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
-def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
+def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
r = bytes_to_number(raw_sig[:num_bytes])
s = bytes_to_number(raw_sig[num_bytes:])
- return encode_dss_signature(r, s)
+ return bytes(encode_dss_signature(r, s))
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
diff --git a/pyproject.toml b/pyproject.toml
index f406505..e7920ff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,3 +23,11 @@ python_version = 3.7
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
+strict = true
+# TODO: remove these strict loosenings when possible
+allow_incomplete_defs = true
+allow_untyped_defs = true
+
+[[tool.mypy.overrides]]
+module = "tests.*"
+disallow_untyped_calls = false
diff --git a/tests/test_advisory.py b/tests/test_advisory.py
index ed768d4..8a7e05e 100644
--- a/tests/test_advisory.py
+++ b/tests/test_advisory.py
@@ -1,6 +1,7 @@
import pytest
import jwt
+from jwt.algorithms import get_default_algorithms
from jwt.exceptions import InvalidKeyError
from .utils import crypto_required
@@ -50,18 +51,20 @@ class TestAdvisory:
# public key is a HMAC secret
encoded_bad = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0ZXN0IjoxMjM0fQ.6ulDpqSlbHmQ8bZXhZRLFko9SwcHrghCwh8d-exJEE4"
+ algorithm_names = list(get_default_algorithms())
+
# Both of the jwt tokens are validated as valid
jwt.decode(
encoded_good,
pub_key_bytes,
- algorithms=jwt.algorithms.get_default_algorithms(),
+ algorithms=algorithm_names,
)
with pytest.raises(InvalidKeyError):
jwt.decode(
encoded_bad,
pub_key_bytes,
- algorithms=jwt.algorithms.get_default_algorithms(),
+ algorithms=algorithm_names,
)
# Of course the receiver should specify ed25519 algorithm to be used if
@@ -100,16 +103,17 @@ class TestAdvisory:
# encoded_bad = jwt.encode({"test": 1234}, ssh_key_bytes, algorithm="HS256")
encoded_bad = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0ZXN0IjoxMjM0fQ.5eYfbrbeGYmWfypQ6rMWXNZ8bdHcqKng5GPr9MJZITU"
+ algorithm_names = list(get_default_algorithms())
# Both of the jwt tokens are validated as valid
jwt.decode(
encoded_good,
ssh_key_bytes,
- algorithms=jwt.algorithms.get_default_algorithms(),
+ algorithms=algorithm_names,
)
with pytest.raises(InvalidKeyError):
jwt.decode(
encoded_bad,
ssh_key_bytes,
- algorithms=jwt.algorithms.get_default_algorithms(),
+ algorithms=algorithm_names,
)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 894ce28..3ace425 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -25,19 +25,19 @@ class TestAlgorithms:
algo = Algorithm()
with pytest.raises(NotImplementedError):
- algo.sign("message", "key")
+ algo.sign(b"message", "key")
def test_algorithm_should_throw_exception_if_verify_not_impl(self):
algo = Algorithm()
with pytest.raises(NotImplementedError):
- algo.verify("message", "key", "signature")
+ algo.verify(b"message", "key", b"signature")
def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
algo = Algorithm()
with pytest.raises(NotImplementedError):
- algo.from_jwk("value")
+ algo.from_jwk({"val": "ue"})
def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
algo = Algorithm()
diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py
index 040e81b..3f40fc5 100644
--- a/tests/test_api_jwk.py
+++ b/tests/test_api_jwk.py
@@ -293,7 +293,7 @@ class TestPyJWKSet:
def test_invalid_keys_list(self):
with pytest.raises(PyJWKSetError) as err:
- PyJWKSet(keys="string")
+ PyJWKSet(keys="string") # type: ignore
assert str(err.value) == "Invalid JWK Set value"
def test_empty_keys_list(self):