summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViicos <65306057+Viicos@users.noreply.github.com>2023-04-16 09:51:09 +0200
committerGitHub <noreply@github.com>2023-04-16 13:51:09 +0600
commit56b3d5633160e79e1f4c5c09023d68759cbf84a6 (patch)
tree8daa4bef8877eb6dcbb742bb0f63842ca70a51e6
parentba726444a6cee75af59feb8ea08294d0ac89bedb (diff)
downloadpyjwt-56b3d5633160e79e1f4c5c09023d68759cbf84a6.tar.gz
Add complete types to take all allowed keys into account (#873)
* Use new style typing * Fix type annotations to allow all keys * Use string type annotations where required * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove outdated comment * Ignore `if TYPE_CHECKING:` lines in coverage * Remove duplicate test * Fix mypy errors * Update algorithms.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fully switch to modern annotations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update `pre-commit` mypy config * Use Python 3.11 for mypy * Update mypy Python version in `pyproject.toml` * Few tests mypy fixes * fix mypy errors on tests * Fix key imports * Remove unused import * Fix randomly failing test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Asif Saif Uddin <auvipy@gmail.com>
-rw-r--r--.pre-commit-config.yaml1
-rw-r--r--jwt/algorithms.py189
-rw-r--r--jwt/api_jwk.py12
-rw-r--r--jwt/api_jws.py23
-rw-r--r--jwt/api_jwt.py67
-rw-r--r--jwt/utils.py8
-rw-r--r--pyproject.toml3
-rw-r--r--tests/test_algorithms.py116
-rw-r--r--tests/test_api_jws.py10
-rw-r--r--tests/test_api_jwt.py4
-rw-r--r--tests/test_utils.py2
11 files changed, 241 insertions, 194 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3af45df..25996a7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,3 +39,4 @@ repos:
rev: "v1.1.1"
hooks:
- id: mypy
+ additional_dependencies: [cryptography>=3.4.0]
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index bc928fe..4bcf81b 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import hashlib
import hmac
import json
from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Dict, Type, Union
+from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
@@ -19,7 +21,6 @@ from .utils import (
)
try:
- import cryptography.exceptions
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
@@ -68,6 +69,23 @@ try:
except ModuleNotFoundError:
has_crypto = False
+
+if TYPE_CHECKING:
+ # Type aliases for convenience in algorithms method signatures
+ AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
+ AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
+ AllowedOKPKeys = (
+ Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
+ )
+ AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
+ AllowedPrivateKeys = (
+ RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
+ )
+ AllowedPublicKeys = (
+ RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
+ )
+
+
requires_cryptography = {
"RS256",
"RS384",
@@ -84,7 +102,7 @@ requires_cryptography = {
}
-def get_default_algorithms() -> Dict[str, "Algorithm"]:
+def get_default_algorithms() -> dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
@@ -145,10 +163,6 @@ class Algorithm(ABC):
else:
return bytes(hash_alg(bytestr).digest())
- # TODO: all key-related `Any`s in this class should optimally be made
- # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
- # that may still be poorly supported.
-
@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
@@ -172,16 +186,16 @@ class Algorithm(ABC):
@staticmethod
@abstractmethod
- def to_jwk(key_obj) -> JWKDict:
+ def to_jwk(key_obj) -> str:
"""
- Serializes a given RSA key into a JWK
+ Serializes a given key into a JWK
"""
@staticmethod
@abstractmethod
- def from_jwk(jwk: JWKDict):
+ def from_jwk(jwk: str | JWKDict) -> Any:
"""
- Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
+ Deserializes a given key from JWK back into a key object
"""
@@ -191,7 +205,7 @@ class NoneAlgorithm(Algorithm):
operations are required.
"""
- def prepare_key(self, key):
+ def prepare_key(self, key: str | None) -> None:
if key == "":
key = None
@@ -200,18 +214,18 @@ class NoneAlgorithm(Algorithm):
return key
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: None) -> bytes:
return b""
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False
@staticmethod
- def to_jwk(key_obj) -> JWKDict:
+ def to_jwk(key_obj: Any) -> NoReturn:
raise NotImplementedError()
@staticmethod
- def from_jwk(jwk: JWKDict):
+ def from_jwk(jwk: str | JWKDict) -> NoReturn:
raise NotImplementedError()
@@ -228,19 +242,19 @@ class HMACAlgorithm(Algorithm):
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
- key = force_bytes(key)
+ def prepare_key(self, key: str | bytes) -> bytes:
+ key_bytes = force_bytes(key)
- if is_pem_format(key) or is_ssh_key(key):
+ if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
- return key
+ return key_bytes
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj: str | bytes) -> str:
return json.dumps(
{
"k": base64url_encode(force_bytes(key_obj)).decode(),
@@ -249,10 +263,10 @@ class HMACAlgorithm(Algorithm):
)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> bytes:
try:
if isinstance(jwk, str):
- obj = json.loads(jwk)
+ obj: JWKDict = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
@@ -268,7 +282,7 @@ class HMACAlgorithm(Algorithm):
def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
return hmac.compare_digest(sig, self.sign(msg, key))
@@ -280,14 +294,14 @@ if has_crypto:
RSASSA-PKCS-v1_5 and the specified hash function.
"""
- SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256
- SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384
- SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
return key
@@ -298,15 +312,17 @@ if has_crypto:
try:
if key_bytes.startswith(b"ssh-rsa"):
- return load_ssh_public_key(key_bytes)
+ return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else:
- return load_pem_private_key(key_bytes, password=None)
+ return cast(
+ RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
+ )
except ValueError:
- return load_pem_public_key(key_bytes)
+ return cast(RSAPublicKey, load_pem_public_key(key_bytes))
@staticmethod
- def to_jwk(key_obj):
- obj = None
+ def to_jwk(key_obj: AllowedRSAKeys) -> str:
+ obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"):
# Private key
@@ -341,7 +357,7 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -412,10 +428,10 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True
@@ -428,14 +444,14 @@ if has_crypto:
ECDSA and the specified hash function
"""
- SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256
- SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384
- SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key
@@ -449,41 +465,46 @@ if has_crypto:
# the Verifying Key first.
try:
if key_bytes.startswith(b"ecdsa-sha2-"):
- key = load_ssh_public_key(key_bytes)
+ crypto_key = load_ssh_public_key(key_bytes)
else:
- key = load_pem_public_key(key_bytes)
+ crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError:
- key = load_pem_private_key(key_bytes, password=None)
+ crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
- if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
+ if not isinstance(
+ crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
+ ):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
)
- return key
+ return crypto_key
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
try:
der_sig = raw_to_der_signature(sig, key.curve)
except ValueError:
return False
try:
- if isinstance(key, EllipticCurvePrivateKey):
- key = key.public_key()
- key.verify(der_sig, msg, ECDSA(self.hash_alg()))
+ public_key = (
+ key.public_key()
+ if isinstance(key, EllipticCurvePrivateKey)
+ else key
+ )
+ public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj: AllowedECKeys) -> str:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
@@ -502,7 +523,7 @@ if has_crypto:
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
- obj = {
+ obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
@@ -517,9 +538,7 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(
- jwk: Any,
- ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
+ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -591,7 +610,7 @@ if has_crypto:
Performs a signature using RSASSA-PSS with MGF1
"""
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(
msg,
padding.PSS(
@@ -601,7 +620,7 @@ if has_crypto:
self.hash_alg(),
)
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(
sig,
@@ -623,21 +642,20 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
- def __init__(self, **kwargs) -> None:
+ def __init__(self, **kwargs: Any) -> None:
pass
- def prepare_key(self, key):
+ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)):
- if isinstance(key, str):
- key = key.encode("utf-8")
- str_key = key.decode("utf-8")
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+ key_bytes = key.encode("utf-8") if isinstance(key, str) else key
- if "-----BEGIN PUBLIC" in str_key:
- key = load_pem_public_key(key)
- elif "-----BEGIN PRIVATE" in str_key:
- key = load_pem_private_key(key, password=None)
- elif str_key[0:4] == "ssh-":
- key = load_ssh_public_key(key)
+ if "-----BEGIN PUBLIC" in key_str:
+ key = load_pem_public_key(key_bytes) # type: ignore[assignment]
+ elif "-----BEGIN PRIVATE" in key_str:
+ key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
+ elif key_str[0:4] == "ssh-":
+ key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
@@ -650,7 +668,9 @@ if has_crypto:
return key
- def sign(self, msg, key):
+ def sign(
+ self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
+ ) -> bytes:
"""
Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign
@@ -658,10 +678,12 @@ if has_crypto:
or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes
"""
- msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
- return key.sign(msg)
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ return key.sign(msg_bytes)
- def verify(self, msg, key, sig):
+ def verify(
+ self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
+ ) -> bool:
"""
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
@@ -672,18 +694,21 @@ if has_crypto:
:return bool verified: True if signature is valid, False if not.
"""
try:
- msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
- sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
- if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
- key = key.public_key()
- key.verify(sig, msg)
+ public_key = (
+ key.public_key()
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
+ else key
+ )
+ public_key.verify(sig_bytes, msg_bytes)
return True # If no exception was raised, the signature is valid.
- except cryptography.exceptions.InvalidSignature:
+ except InvalidSignature:
return False
@staticmethod
- def to_jwk(key):
+ def to_jwk(key: AllowedOKPKeys) -> str:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
@@ -723,7 +748,7 @@ if has_crypto:
raise InvalidKeyError("Not a public or private key")
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index fdcde21..dd876f7 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import json
import time
-from typing import Any, Optional
+from typing import Any
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
@@ -10,7 +10,7 @@ from .types import JWKDict
class PyJWK:
- def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None:
+ def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
@@ -60,7 +60,7 @@ class PyJWK:
self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod
- def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK":
+ def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
return PyJWK(obj, algorithm)
@staticmethod
@@ -69,15 +69,15 @@ class PyJWK:
return PyJWK.from_dict(obj, algorithm)
@property
- def key_type(self) -> Optional[str]:
+ def key_type(self) -> str | None:
return self._jwk_data.get("kty", None)
@property
- def key_id(self) -> Optional[str]:
+ def key_id(self) -> str | None:
return self._jwk_data.get("kid", None)
@property
- def public_key_use(self) -> Optional[str]:
+ def public_key_use(self) -> str | None:
return self._jwk_data.get("use", None)
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 0a775d7..fa6708c 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import binascii
import json
import warnings
-from typing import Any, List, Optional, Type
+from typing import TYPE_CHECKING, Any
from .algorithms import (
Algorithm,
@@ -20,14 +20,17 @@ from .exceptions import (
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
+if TYPE_CHECKING:
+ from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
+
class PyJWS:
header_typ = "JWT"
def __init__(
self,
- algorithms: Optional[List[str]] = None,
- options: Optional[dict[str, Any]] = None,
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
@@ -100,10 +103,10 @@ class PyJWS:
def encode(
self,
payload: bytes,
- key: str,
+ key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
- json_encoder: Type[json.JSONEncoder] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
@@ -169,7 +172,7 @@ class PyJWS:
def decode_complete(
self,
jwt: str | bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -214,7 +217,7 @@ class PyJWS:
def decode(
self,
jwt: str | bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@@ -286,7 +289,7 @@ class PyJWS:
signing_input: bytes,
header: dict[str, Any],
signature: bytes,
- key: str | bytes = "",
+ key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
try:
@@ -301,9 +304,9 @@ class PyJWS:
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
- key = alg_obj.prepare_key(key)
+ prepared_key = alg_obj.prepare_key(key)
- if not alg_obj.verify(signing_input, key, signature):
+ if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers: dict[str, Any]) -> None:
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index d85f6e8..49d1b48 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -5,7 +5,7 @@ import warnings
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any
from . import api_jws
from .exceptions import (
@@ -19,15 +19,18 @@ from .exceptions import (
)
from .warnings import RemovedInPyjwt3Warning
+if TYPE_CHECKING:
+ from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
+
class PyJWT:
- def __init__(self, options: Optional[dict[str, Any]] = None) -> None:
+ def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
- def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
+ def _get_default_options() -> dict[str, bool | list[str]]:
return {
"verify_signature": True,
"verify_exp": True,
@@ -40,11 +43,11 @@ class PyJWT:
def encode(
self,
- payload: Dict[str, Any],
- key: str,
- algorithm: Optional[str] = "HS256",
- headers: Optional[Dict[str, Any]] = None,
- json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ payload: dict[str, Any],
+ key: AllowedPrivateKeys | str | bytes,
+ algorithm: str | None = "HS256",
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str:
# Check that we get a dict
@@ -78,9 +81,9 @@ class PyJWT:
def _encode_payload(
self,
- payload: Dict[str, Any],
- headers: Optional[Dict[str, Any]] = None,
- json_encoder: Optional[Type[json.JSONEncoder]] = None,
+ payload: dict[str, Any],
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
@@ -97,21 +100,21 @@ class PyJWT:
def decode_complete(
self,
jwt: str | bytes,
- key: str | bytes = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
- verify: Optional[bool] = None,
+ verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
- detached_payload: Optional[bytes] = None,
+ detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
- audience: Optional[Union[str, Iterable[str]]] = None,
- issuer: Optional[str] = None,
- leeway: Union[int, float, timedelta] = 0,
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
# kwargs
- **kwargs,
- ) -> Dict[str, Any]:
+ **kwargs: Any,
+ ) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
@@ -163,7 +166,7 @@ class PyJWT:
decoded["payload"] = payload
return decoded
- def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
+ def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
@@ -182,20 +185,20 @@ class PyJWT:
def decode(
self,
jwt: str | bytes,
- key: str | bytes = "",
- algorithms: Optional[List[str]] = None,
- options: Optional[Dict[str, Any]] = None,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
- verify: Optional[bool] = None,
+ verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
- detached_payload: Optional[bytes] = None,
+ detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
- audience: Optional[Union[str, Iterable[str]]] = None,
- issuer: Optional[str] = None,
- leeway: Union[int, float, timedelta] = 0,
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
# kwargs
- **kwargs,
+ **kwargs: Any,
) -> Any:
if kwargs:
warnings.warn(
@@ -303,7 +306,7 @@ class PyJWT:
def _validate_aud(
self,
payload: dict[str, Any],
- audience: Optional[Union[str, Iterable[str]]],
+ audience: str | Iterable[str] | None,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
diff --git a/jwt/utils.py b/jwt/utils.py
index 92e3c8b..81c5ee4 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -1,7 +1,7 @@
import base64
import binascii
import re
-from typing import Any, AnyStr
+from typing import Union
try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
@@ -13,7 +13,7 @@ except ModuleNotFoundError:
pass
-def force_bytes(value: Any) -> bytes:
+def force_bytes(value: Union[bytes, str]) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
@@ -22,7 +22,7 @@ def force_bytes(value: Any) -> bytes:
raise TypeError("Expected a string value")
-def base64url_decode(input: AnyStr) -> bytes:
+def base64url_decode(input: Union[bytes, str]) -> bytes:
input_bytes = force_bytes(input)
rem = len(input_bytes) % 4
@@ -49,7 +49,7 @@ def to_base64url_uint(val: int) -> bytes:
return base64url_encode(int_bytes)
-def from_base64url_uint(val: AnyStr) -> int:
+def from_base64url_uint(val: Union[bytes, str]) -> int:
data = base64url_decode(force_bytes(val))
return int.from_bytes(data, byteorder="big")
diff --git a/pyproject.toml b/pyproject.toml
index e7920ff..40dfb7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ source = ["jwt", ".tox/*/site-packages"]
[tool.coverage.report]
show_missing = true
+exclude_lines = ["if TYPE_CHECKING:"]
[tool.isort]
profile = "black"
@@ -19,7 +20,7 @@ atomic = true
combine_as_imports = true
[tool.mypy]
-python_version = 3.7
+python_version = 3.11
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 8aa9ad7..7a90376 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,5 +1,6 @@
import base64
import json
+from typing import cast
import pytest
@@ -11,6 +12,23 @@ from .keys import load_ec_pub_key_p_521, load_hmac_key, load_rsa_pub_key
from .utils import crypto_required, key_path
if has_crypto:
+ from cryptography.hazmat.primitives.asymmetric.ec import (
+ EllipticCurvePrivateKey,
+ EllipticCurvePublicKey,
+ )
+ from cryptography.hazmat.primitives.asymmetric.ed448 import (
+ Ed448PrivateKey,
+ Ed448PublicKey,
+ )
+ from cryptography.hazmat.primitives.asymmetric.ed25519 import (
+ Ed25519PrivateKey,
+ Ed25519PublicKey,
+ )
+ from cryptography.hazmat.primitives.asymmetric.rsa import (
+ RSAPrivateKey,
+ RSAPublicKey,
+ )
+
from jwt.algorithms import ECAlgorithm, OKPAlgorithm, RSAAlgorithm, RSAPSSAlgorithm
@@ -37,7 +55,7 @@ class TestAlgorithms:
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
with pytest.raises(TypeError) as context:
- algo.prepare_key(object())
+ algo.prepare_key(object()) # type: ignore[arg-type]
exception = context.value
assert str(exception) == "Expected a string value"
@@ -112,7 +130,7 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with pytest.raises(TypeError):
- algo.prepare_key(None)
+ algo.prepare_key(None) # type: ignore[arg-type]
@crypto_required
def test_rsa_verify_should_return_false_if_signature_invalid(self):
@@ -132,7 +150,7 @@ class TestAlgorithms:
sig += b"123" # Signature is now invalid
with open(key_path("testkey_rsa.pub")) as keyfile:
- pub_key = algo.prepare_key(keyfile.read())
+ pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read()))
result = algo.verify(message, pub_key, sig)
assert not result
@@ -149,10 +167,10 @@ class TestAlgorithms:
algo = ECAlgorithm(hash)
with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile:
- pub_key = algo.from_jwk(keyfile.read())
+ pub_key = cast(EllipticCurvePublicKey, algo.from_jwk(keyfile.read()))
with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile:
- priv_key = algo.from_jwk(keyfile.read())
+ priv_key = cast(EllipticCurvePrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", priv_key)
assert algo.verify(b"Hello World!", pub_key, signature)
@@ -223,9 +241,9 @@ class TestAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path("testkey_ec.priv")) as ec_key:
- orig_key = algo.prepare_key(ec_key.read())
+ orig_key = cast(EllipticCurvePrivateKey, algo.prepare_key(ec_key.read()))
- parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ parsed_key = cast(EllipticCurvePrivateKey, algo.from_jwk(algo.to_jwk(orig_key)))
assert parsed_key.private_numbers() == orig_key.private_numbers()
assert (
parsed_key.private_numbers().public_numbers
@@ -237,9 +255,9 @@ class TestAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with open(key_path("testkey_ec.pub")) as ec_key:
- orig_key = algo.prepare_key(ec_key.read())
+ orig_key = cast(EllipticCurvePublicKey, algo.prepare_key(ec_key.read()))
- parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ parsed_key = cast(EllipticCurvePublicKey, algo.from_jwk(algo.to_jwk(orig_key)))
assert parsed_key.public_numbers() == orig_key.public_numbers()
@crypto_required
@@ -284,7 +302,7 @@ class TestAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"})
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
@crypto_required
def test_ec_to_jwk_with_valid_curves(self):
@@ -320,10 +338,10 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path("jwk_rsa_pub.json")) as keyfile:
- pub_key = algo.from_jwk(keyfile.read())
+ pub_key = cast(RSAPublicKey, algo.from_jwk(keyfile.read()))
with open(key_path("jwk_rsa_key.json")) as keyfile:
- priv_key = algo.from_jwk(keyfile.read())
+ priv_key = cast(RSAPrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", priv_key)
assert algo.verify(b"Hello World!", pub_key, signature)
@@ -333,9 +351,9 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path("testkey_rsa.priv")) as rsa_key:
- orig_key = algo.prepare_key(rsa_key.read())
+ orig_key = cast(RSAPrivateKey, algo.prepare_key(rsa_key.read()))
- parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ parsed_key = cast(RSAPrivateKey, algo.from_jwk(algo.to_jwk(orig_key)))
assert parsed_key.private_numbers() == orig_key.private_numbers()
assert (
parsed_key.private_numbers().public_numbers
@@ -347,9 +365,9 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with open(key_path("testkey_rsa.pub")) as rsa_key:
- orig_key = algo.prepare_key(rsa_key.read())
+ orig_key = cast(RSAPublicKey, algo.prepare_key(rsa_key.read()))
- parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
+ parsed_key = cast(RSAPublicKey, algo.from_jwk(algo.to_jwk(orig_key)))
assert parsed_key.public_numbers() == orig_key.public_numbers()
@crypto_required
@@ -380,14 +398,16 @@ class TestAlgorithms:
with open(key_path("jwk_rsa_key.json")) as keyfile:
keybytes = keyfile.read()
- control_key = algo.from_jwk(keybytes).private_numbers()
+ control_key = cast(RSAPrivateKey, algo.from_jwk(keybytes)).private_numbers()
keydata = json.loads(keybytes)
delete_these = ["p", "q", "dp", "dq", "qi"]
for field in delete_these:
del keydata[field]
- parsed_key = algo.from_jwk(json.dumps(keydata)).private_numbers()
+ parsed_key = cast(
+ RSAPrivateKey, algo.from_jwk(json.dumps(keydata))
+ ).private_numbers()
assert control_key.d == parsed_key.d
assert control_key.p == parsed_key.p
@@ -505,7 +525,7 @@ class TestAlgorithms:
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"})
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
@crypto_required
def test_rsa_from_jwk_raises_exception_on_invalid_key(self):
@@ -520,7 +540,7 @@ class TestAlgorithms:
algo = ECAlgorithm(ECAlgorithm.SHA256)
with pytest.raises(TypeError):
- algo.prepare_key(None)
+ algo.prepare_key(None) # type: ignore[arg-type]
@crypto_required
def test_ec_should_accept_pem_private_key_bytes(self):
@@ -590,11 +610,11 @@ class TestAlgorithms:
message = b"Hello World!"
with open(key_path("testkey_rsa.priv")) as keyfile:
- priv_key = algo.prepare_key(keyfile.read())
+ priv_key = cast(RSAPrivateKey, algo.prepare_key(keyfile.read()))
sig = algo.sign(message, priv_key)
with open(key_path("testkey_rsa.pub")) as keyfile:
- pub_key = algo.prepare_key(keyfile.read())
+ pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read()))
result = algo.verify(message, pub_key, sig)
assert result
@@ -617,7 +637,7 @@ class TestAlgorithms:
jwt_sig += b"123" # Signature is now invalid
with open(key_path("testkey_rsa.pub")) as keyfile:
- jwt_pub_key = algo.prepare_key(keyfile.read())
+ jwt_pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read()))
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
assert not result
@@ -678,7 +698,7 @@ class TestAlgorithmsRFC7520:
)
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
- key = algo.prepare_key(load_rsa_pub_key())
+ key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key()))
result = algo.verify(signing_input, key, signature)
assert result
@@ -709,7 +729,7 @@ class TestAlgorithmsRFC7520:
)
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
- key = algo.prepare_key(load_rsa_pub_key())
+ key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key()))
result = algo.verify(signing_input, key, signature)
assert result
@@ -759,7 +779,7 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with pytest.raises(InvalidKeyError):
- algo.prepare_key(None)
+ algo.prepare_key(None) # type: ignore[arg-type]
with open(key_path("testkey_ed25519")) as keyfile:
algo.prepare_key(keyfile.read())
@@ -767,12 +787,6 @@ class TestOKPAlgorithms:
with open(key_path("testkey_ed25519.pub")) as keyfile:
algo.prepare_key(keyfile.read())
- def test_okp_ed25519_should_accept_unicode_key(self):
- algo = OKPAlgorithm()
-
- with open(key_path("testkey_ed25519")) as ec_key:
- algo.prepare_key(ec_key.read())
-
def test_okp_ed25519_sign_should_generate_correct_signature_value(self):
algo = OKPAlgorithm()
@@ -781,10 +795,10 @@ class TestOKPAlgorithms:
expected_sig = base64.b64decode(self.hello_world_sig)
with open(key_path("testkey_ed25519")) as keyfile:
- jwt_key = algo.prepare_key(keyfile.read())
+ jwt_key = cast(Ed25519PrivateKey, algo.prepare_key(keyfile.read()))
with open(key_path("testkey_ed25519.pub")) as keyfile:
- jwt_pub_key = algo.prepare_key(keyfile.read())
+ jwt_pub_key = cast(Ed25519PublicKey, algo.prepare_key(keyfile.read()))
algo.sign(jwt_message, jwt_key)
result = algo.verify(jwt_message, jwt_pub_key, expected_sig)
@@ -829,7 +843,7 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile:
- key = algo.from_jwk(keyfile.read())
+ key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", key)
assert algo.verify(b"Hello World!", key.public_key(), signature)
@@ -840,7 +854,7 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile:
- key = algo.from_jwk(keyfile.read())
+ key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", key)
assert algo.verify(b"Hello World!", key, signature)
@@ -849,10 +863,10 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile:
- priv_key = algo.from_jwk(keyfile.read())
+ priv_key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read()))
with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile:
- pub_key = algo.from_jwk(keyfile.read())
+ pub_key = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", priv_key)
assert algo.verify(b"Hello World!", pub_key, signature)
@@ -867,7 +881,7 @@ class TestOKPAlgorithms:
# Invalid instance type
with pytest.raises(InvalidKeyError):
- algo.from_jwk(123)
+ algo.from_jwk(123) # type: ignore[arg-type]
# Invalid JSON
with pytest.raises(InvalidKeyError):
@@ -913,15 +927,15 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile:
- priv_key_1 = algo.from_jwk(keyfile.read())
+ priv_key_1 = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read()))
with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile:
- pub_key_1 = algo.from_jwk(keyfile.read())
+ pub_key_1 = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read()))
pub = algo.to_jwk(pub_key_1)
pub_key_2 = algo.from_jwk(pub)
pri = algo.to_jwk(priv_key_1)
- priv_key_2 = algo.from_jwk(pri)
+ priv_key_2 = cast(Ed25519PrivateKey, algo.from_jwk(pri))
signature_1 = algo.sign(b"Hello World!", priv_key_1)
signature_2 = algo.sign(b"Hello World!", priv_key_2)
@@ -932,13 +946,13 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with pytest.raises(InvalidKeyError):
- algo.to_jwk({"not": "a valid key"})
+ algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type]
def test_okp_ed448_jwk_private_key_should_parse_and_verify(self):
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed448.json")) as keyfile:
- key = algo.from_jwk(keyfile.read())
+ key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", key)
assert algo.verify(b"Hello World!", key.public_key(), signature)
@@ -949,7 +963,7 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed448.json")) as keyfile:
- key = algo.from_jwk(keyfile.read())
+ key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", key)
assert algo.verify(b"Hello World!", key, signature)
@@ -958,10 +972,10 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed448.json")) as keyfile:
- priv_key = algo.from_jwk(keyfile.read())
+ priv_key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read()))
with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile:
- pub_key = algo.from_jwk(keyfile.read())
+ pub_key = cast(Ed448PublicKey, algo.from_jwk(keyfile.read()))
signature = algo.sign(b"Hello World!", priv_key)
assert algo.verify(b"Hello World!", pub_key, signature)
@@ -976,7 +990,7 @@ class TestOKPAlgorithms:
# Invalid instance type
with pytest.raises(InvalidKeyError):
- algo.from_jwk(123)
+ algo.from_jwk(123) # type: ignore[arg-type]
# Invalid JSON
with pytest.raises(InvalidKeyError):
@@ -1022,15 +1036,15 @@ class TestOKPAlgorithms:
algo = OKPAlgorithm()
with open(key_path("jwk_okp_key_Ed448.json")) as keyfile:
- priv_key_1 = algo.from_jwk(keyfile.read())
+ priv_key_1 = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read()))
with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile:
- pub_key_1 = algo.from_jwk(keyfile.read())
+ pub_key_1 = cast(Ed448PublicKey, algo.from_jwk(keyfile.read()))
pub = algo.to_jwk(pub_key_1)
pub_key_2 = algo.from_jwk(pub)
pri = algo.to_jwk(priv_key_1)
- priv_key_2 = algo.from_jwk(pri)
+ priv_key_2 = cast(Ed448PrivateKey, algo.from_jwk(pri))
signature_1 = algo.sign(b"Hello World!", priv_key_1)
signature_2 = algo.sign(b"Hello World!", priv_key_2)
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index d2aa915..3385716 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -82,7 +82,7 @@ class TestJWS:
assert jws.options["verify_signature"]
- def test_options_must_be_dict(self, jws):
+ def test_options_must_be_dict(self):
pytest.raises(TypeError, PyJWS, options=object())
pytest.raises((TypeError, ValueError), PyJWS, options=("something"))
@@ -533,11 +533,11 @@ class TestJWS:
# string-formatted key
with open(key_path("testkey_rsa.priv")) as rsa_priv_file:
- priv_rsakey = rsa_priv_file.read()
+ priv_rsakey = rsa_priv_file.read() # type: ignore[assignment]
jws_message = jws.encode(payload, priv_rsakey, algorithm=algo)
with open(key_path("testkey_rsa.pub")) as rsa_pub_file:
- pub_rsakey = rsa_pub_file.read()
+ pub_rsakey = rsa_pub_file.read() # type: ignore[assignment]
jws.decode(jws_message, pub_rsakey, algorithms=[algo])
def test_rsa_related_algorithms(self, jws):
@@ -582,11 +582,11 @@ class TestJWS:
# string-formatted key
with open(key_path("testkey_ec.priv")) as ec_priv_file:
- priv_eckey = ec_priv_file.read()
+ priv_eckey = ec_priv_file.read() # type: ignore[assignment]
jws_message = jws.encode(payload, priv_eckey, algorithm=algo)
with open(key_path("testkey_ec.pub")) as ec_pub_file:
- pub_eckey = ec_pub_file.read()
+ pub_eckey = ec_pub_file.read() # type: ignore[assignment]
jws.decode(jws_message, pub_eckey, algorithms=[algo])
def test_ecdsa_related_algorithms(self, jws):
diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py
index 3c08ea5..0d53444 100644
--- a/tests/test_api_jwt.py
+++ b/tests/test_api_jwt.py
@@ -365,8 +365,8 @@ class TestJWT:
secret = "secret"
jwt_message = jwt.encode(payload, secret)
- # With 3 seconds leeway, should be ok
- for leeway in (3, timedelta(seconds=3)):
+ # With 5 seconds leeway, should be ok
+ for leeway in (5, timedelta(seconds=5)):
decoded = jwt.decode(
jwt_message, secret, leeway=leeway, algorithms=["HS256"]
)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a089f86..122dcb4 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -36,4 +36,4 @@ def test_from_base64url_uint(inputval, expected):
def test_force_bytes_raises_error_on_invalid_object():
with pytest.raises(TypeError):
- force_bytes({})
+ force_bytes({}) # type: ignore[arg-type]