From 56b3d5633160e79e1f4c5c09023d68759cbf84a6 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 16 Apr 2023 09:51:09 +0200 Subject: Add complete types to take all allowed keys into account (#873) * Use new style typing * Fix type annotations to allow all keys * Use string type annotations where required * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove outdated comment * Ignore `if TYPE_CHECKING:` lines in coverage * Remove duplicate test * Fix mypy errors * Update algorithms.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fully switch to modern annotations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update `pre-commit` mypy config * Use Python 3.11 for mypy * Update mypy Python version in `pyproject.toml` * Few tests mypy fixes * fix mypy errors on tests * Fix key imports * Remove unused import * Fix randomly failing test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Asif Saif Uddin --- .pre-commit-config.yaml | 1 + jwt/algorithms.py | 189 +++++++++++++++++++++++++++-------------------- jwt/api_jwk.py | 12 +-- jwt/api_jws.py | 23 +++--- jwt/api_jwt.py | 67 +++++++++-------- jwt/utils.py | 8 +- pyproject.toml | 3 +- tests/test_algorithms.py | 116 ++++++++++++++++------------- tests/test_api_jws.py | 10 +-- tests/test_api_jwt.py | 4 +- tests/test_utils.py | 2 +- 11 files changed, 241 insertions(+), 194 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3af45df..25996a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,3 +39,4 @@ repos: rev: "v1.1.1" hooks: - id: mypy + additional_dependencies: [cryptography>=3.4.0] diff --git a/jwt/algorithms.py b/jwt/algorithms.py index bc928fe..4bcf81b 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import hashlib import hmac import json from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Type, Union +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -19,7 +21,6 @@ from .utils import ( ) try: - import cryptography.exceptions from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes @@ -68,6 +69,23 @@ try: except ModuleNotFoundError: has_crypto = False + +if TYPE_CHECKING: + # Type aliases for convenience in algorithms method signatures + AllowedRSAKeys = RSAPrivateKey | RSAPublicKey + AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey + AllowedOKPKeys = ( + Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + ) + AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys + AllowedPrivateKeys = ( + RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + ) + AllowedPublicKeys = ( + RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + ) + + requires_cryptography = { "RS256", "RS384", @@ -84,7 +102,7 @@ requires_cryptography = { } -def get_default_algorithms() -> Dict[str, "Algorithm"]: +def get_default_algorithms() -> dict[str, Algorithm]: """ Returns the algorithms that are implemented by the library. """ @@ -145,10 +163,6 @@ class Algorithm(ABC): else: return bytes(hash_alg(bytestr).digest()) - # TODO: all key-related `Any`s in this class should optimally be made - # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 - # that may still be poorly supported. - @abstractmethod def prepare_key(self, key: Any) -> Any: """ @@ -172,16 +186,16 @@ class Algorithm(ABC): @staticmethod @abstractmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj) -> str: """ - Serializes a given RSA key into a JWK + Serializes a given key into a JWK """ @staticmethod @abstractmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: str | JWKDict) -> Any: """ - Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + Deserializes a given key from JWK back into a key object """ @@ -191,7 +205,7 @@ class NoneAlgorithm(Algorithm): operations are required. """ - def prepare_key(self, key): + def prepare_key(self, key: str | None) -> None: if key == "": key = None @@ -200,18 +214,18 @@ class NoneAlgorithm(Algorithm): return key - def sign(self, msg, key): + def sign(self, msg: bytes, key: None) -> bytes: return b"" - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: None, sig: bytes) -> bool: return False @staticmethod - def to_jwk(key_obj) -> JWKDict: + def to_jwk(key_obj: Any) -> NoReturn: raise NotImplementedError() @staticmethod - def from_jwk(jwk: JWKDict): + def from_jwk(jwk: str | JWKDict) -> NoReturn: raise NotImplementedError() @@ -228,19 +242,19 @@ class HMACAlgorithm(Algorithm): def __init__(self, hash_alg: HashlibHash) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): - key = force_bytes(key) + def prepare_key(self, key: str | bytes) -> bytes: + key_bytes = force_bytes(key) - if is_pem_format(key) or is_ssh_key(key): + if is_pem_format(key_bytes) or is_ssh_key(key_bytes): raise InvalidKeyError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." ) - return key + return key_bytes @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: str | bytes) -> str: return json.dumps( { "k": base64url_encode(force_bytes(key_obj)).decode(), @@ -249,10 +263,10 @@ class HMACAlgorithm(Algorithm): ) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> bytes: try: if isinstance(jwk, str): - obj = json.loads(jwk) + obj: JWKDict = json.loads(jwk) elif isinstance(jwk, dict): obj = jwk else: @@ -268,7 +282,7 @@ class HMACAlgorithm(Algorithm): def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool: return hmac.compare_digest(sig, self.sign(msg, key)) @@ -280,14 +294,14 @@ if has_crypto: RSASSA-PKCS-v1_5 and the specified hash function. """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key @@ -298,15 +312,17 @@ if has_crypto: try: if key_bytes.startswith(b"ssh-rsa"): - return load_ssh_public_key(key_bytes) + return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) else: - return load_pem_private_key(key_bytes, password=None) + return cast( + RSAPrivateKey, load_pem_private_key(key_bytes, password=None) + ) except ValueError: - return load_pem_public_key(key_bytes) + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @staticmethod - def to_jwk(key_obj): - obj = None + def to_jwk(key_obj: AllowedRSAKeys) -> str: + obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): # Private key @@ -341,7 +357,7 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -412,10 +428,10 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) return True @@ -428,14 +444,14 @@ if has_crypto: ECDSA and the specified hash function """ - SHA256: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA256 - SHA384: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA384 - SHA512: ClassVar[Type[hashes.HashAlgorithm]] = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg: Type[hashes.HashAlgorithm]) -> None: + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key @@ -449,41 +465,46 @@ if has_crypto: # the Verifying Key first. try: if key_bytes.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key_bytes) + crypto_key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key_bytes) + crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] except ValueError: - key = load_pem_private_key(key_bytes, password=None) + crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography - if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + if not isinstance( + crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) + ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" ) - return key + return crypto_key - def sign(self, msg, key): + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: return False try: - if isinstance(key, EllipticCurvePrivateKey): - key = key.public_key() - key.verify(der_sig, msg, ECDSA(self.hash_alg())) + public_key = ( + key.public_key() + if isinstance(key, EllipticCurvePrivateKey) + else key + ) + public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: AllowedECKeys) -> str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -502,7 +523,7 @@ if has_crypto: else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") - obj = { + obj: dict[str, Any] = { "kty": "EC", "crv": crv, "x": to_base64url_uint(public_numbers.x).decode(), @@ -517,9 +538,7 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk( - jwk: Any, - ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]: + def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -591,7 +610,7 @@ if has_crypto: Performs a signature using RSASSA-PSS with MGF1 """ - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign( msg, padding.PSS( @@ -601,7 +620,7 @@ if has_crypto: self.hash_alg(), ) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify( sig, @@ -623,21 +642,20 @@ if has_crypto: This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key): + def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: if isinstance(key, (bytes, str)): - if isinstance(key, str): - key = key.encode("utf-8") - str_key = key.decode("utf-8") + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + key_bytes = key.encode("utf-8") if isinstance(key, str) else key - if "-----BEGIN PUBLIC" in str_key: - key = load_pem_public_key(key) - elif "-----BEGIN PRIVATE" in str_key: - key = load_pem_private_key(key, password=None) - elif str_key[0:4] == "ssh-": - key = load_ssh_public_key(key) + if "-----BEGIN PUBLIC" in key_str: + key = load_pem_public_key(key_bytes) # type: ignore[assignment] + elif "-----BEGIN PRIVATE" in key_str: + key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + elif key_str[0:4] == "ssh-": + key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography if not isinstance( @@ -650,7 +668,9 @@ if has_crypto: return key - def sign(self, msg, key): + def sign( + self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey + ) -> bytes: """ Sign a message ``msg`` using the EdDSA private key ``key`` :param str|bytes msg: Message to sign @@ -658,10 +678,12 @@ if has_crypto: or :class:`.Ed448PrivateKey` isinstance :return bytes signature: The signature, as bytes """ - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - return key.sign(msg) + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + return key.sign(msg_bytes) - def verify(self, msg, key, sig): + def verify( + self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes + ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -672,18 +694,21 @@ if has_crypto: :return bool verified: True if signature is valid, False if not. """ try: - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig - if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): - key = key.public_key() - key.verify(sig, msg) + public_key = ( + key.public_key() + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) + else key + ) + public_key.verify(sig_bytes, msg_bytes) return True # If no exception was raised, the signature is valid. - except cryptography.exceptions.InvalidSignature: + except InvalidSignature: return False @staticmethod - def to_jwk(key): + def to_jwk(key: AllowedOKPKeys) -> str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -723,7 +748,7 @@ if has_crypto: raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index fdcde21..dd876f7 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -2,7 +2,7 @@ from __future__ import annotations import json import time -from typing import Any, Optional +from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -10,7 +10,7 @@ from .types import JWKDict class PyJWK: - def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None: + def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None: self._algorithms = get_default_algorithms() self._jwk_data = jwk_data @@ -60,7 +60,7 @@ class PyJWK: self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK": + def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod @@ -69,15 +69,15 @@ class PyJWK: return PyJWK.from_dict(obj, algorithm) @property - def key_type(self) -> Optional[str]: + def key_type(self) -> str | None: return self._jwk_data.get("kty", None) @property - def key_id(self) -> Optional[str]: + def key_id(self) -> str | None: return self._jwk_data.get("kid", None) @property - def public_key_use(self) -> Optional[str]: + def public_key_use(self) -> str | None: return self._jwk_data.get("use", None) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 0a775d7..fa6708c 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ from __future__ import annotations import binascii import json import warnings -from typing import Any, List, Optional, Type +from typing import TYPE_CHECKING, Any from .algorithms import ( Algorithm, @@ -20,14 +20,17 @@ from .exceptions import ( from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWS: header_typ = "JWT" def __init__( self, - algorithms: Optional[List[str]] = None, - options: Optional[dict[str, Any]] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( @@ -100,10 +103,10 @@ class PyJWS: def encode( self, payload: bytes, - key: str, + key: AllowedPrivateKeys | str | bytes, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, - json_encoder: Type[json.JSONEncoder] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, sort_headers: bool = True, ) -> str: @@ -169,7 +172,7 @@ class PyJWS: def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -214,7 +217,7 @@ class PyJWS: def decode( self, jwt: str | bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -286,7 +289,7 @@ class PyJWS: signing_input: bytes, header: dict[str, Any], signature: bytes, - key: str | bytes = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, ) -> None: try: @@ -301,9 +304,9 @@ class PyJWS: alg_obj = self.get_algorithm_by_name(alg) except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e - key = alg_obj.prepare_key(key) + prepared_key = alg_obj.prepare_key(key) - if not alg_obj.verify(signing_input, key, signature): + if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers: dict[str, Any]) -> None: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index d85f6e8..49d1b48 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ import warnings from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any from . import api_jws from .exceptions import ( @@ -19,15 +19,18 @@ from .exceptions import ( ) from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWT: - def __init__(self, options: Optional[dict[str, Any]] = None) -> None: + def __init__(self, options: dict[str, Any] | None = None) -> None: if options is None: options = {} self.options: dict[str, Any] = {**self._get_default_options(), **options} @staticmethod - def _get_default_options() -> Dict[str, Union[bool, List[str]]]: + def _get_default_options() -> dict[str, bool | list[str]]: return { "verify_signature": True, "verify_exp": True, @@ -40,11 +43,11 @@ class PyJWT: def encode( self, - payload: Dict[str, Any], - key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + key: AllowedPrivateKeys | str | bytes, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, ) -> str: # Check that we get a dict @@ -78,9 +81,9 @@ class PyJWT: def _encode_payload( self, - payload: Dict[str, Any], - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, ) -> bytes: """ Encode a given payload to the bytes to be signed. @@ -97,21 +100,21 @@ class PyJWT: def decode_complete( self, jwt: str | bytes, - key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, - ) -> Dict[str, Any]: + **kwargs: Any, + ) -> dict[str, Any]: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -163,7 +166,7 @@ class PyJWT: decoded["payload"] = payload return decoded - def _decode_payload(self, decoded: Dict[str, Any]) -> Any: + def _decode_payload(self, decoded: dict[str, Any]) -> Any: """ Decode the payload from a JWS dictionary (payload, signature, header). @@ -182,20 +185,20 @@ class PyJWT: def decode( self, jwt: str | bytes, - key: str | bytes = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, + **kwargs: Any, ) -> Any: if kwargs: warnings.warn( @@ -303,7 +306,7 @@ class PyJWT: def _validate_aud( self, payload: dict[str, Any], - audience: Optional[Union[str, Iterable[str]]], + audience: str | Iterable[str] | None, ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: diff --git a/jwt/utils.py b/jwt/utils.py index 92e3c8b..81c5ee4 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Any, AnyStr +from typing import Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -13,7 +13,7 @@ except ModuleNotFoundError: pass -def force_bytes(value: Any) -> bytes: +def force_bytes(value: Union[bytes, str]) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): @@ -22,7 +22,7 @@ def force_bytes(value: Any) -> bytes: raise TypeError("Expected a string value") -def base64url_decode(input: AnyStr) -> bytes: +def base64url_decode(input: Union[bytes, str]) -> bytes: input_bytes = force_bytes(input) rem = len(input_bytes) % 4 @@ -49,7 +49,7 @@ def to_base64url_uint(val: int) -> bytes: return base64url_encode(int_bytes) -def from_base64url_uint(val: AnyStr) -> int: +def from_base64url_uint(val: Union[bytes, str]) -> int: data = base64url_decode(force_bytes(val)) return int.from_bytes(data, byteorder="big") diff --git a/pyproject.toml b/pyproject.toml index e7920ff..40dfb7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ source = ["jwt", ".tox/*/site-packages"] [tool.coverage.report] show_missing = true +exclude_lines = ["if TYPE_CHECKING:"] [tool.isort] profile = "black" @@ -19,7 +20,7 @@ atomic = true combine_as_imports = true [tool.mypy] -python_version = 3.7 +python_version = 3.11 ignore_missing_imports = true warn_unused_ignores = true no_implicit_optional = true diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 8aa9ad7..7a90376 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,5 +1,6 @@ import base64 import json +from typing import cast import pytest @@ -11,6 +12,23 @@ from .keys import load_ec_pub_key_p_521, load_hmac_key, load_rsa_pub_key from .utils import crypto_required, key_path if has_crypto: + from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePrivateKey, + EllipticCurvePublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed448 import ( + Ed448PrivateKey, + Ed448PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, + ) + from jwt.algorithms import ECAlgorithm, OKPAlgorithm, RSAAlgorithm, RSAPSSAlgorithm @@ -37,7 +55,7 @@ class TestAlgorithms: algo = HMACAlgorithm(HMACAlgorithm.SHA256) with pytest.raises(TypeError) as context: - algo.prepare_key(object()) + algo.prepare_key(object()) # type: ignore[arg-type] exception = context.value assert str(exception) == "Expected a string value" @@ -112,7 +130,7 @@ class TestAlgorithms: algo = RSAAlgorithm(RSAAlgorithm.SHA256) with pytest.raises(TypeError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] @crypto_required def test_rsa_verify_should_return_false_if_signature_invalid(self): @@ -132,7 +150,7 @@ class TestAlgorithms: sig += b"123" # Signature is now invalid with open(key_path("testkey_rsa.pub")) as keyfile: - pub_key = algo.prepare_key(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(message, pub_key, sig) assert not result @@ -149,10 +167,10 @@ class TestAlgorithms: algo = ECAlgorithm(hash) with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(EllipticCurvePublicKey, algo.from_jwk(keyfile.read())) with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(EllipticCurvePrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -223,9 +241,9 @@ class TestAlgorithms: algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.priv")) as ec_key: - orig_key = algo.prepare_key(ec_key.read()) + orig_key = cast(EllipticCurvePrivateKey, algo.prepare_key(ec_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(EllipticCurvePrivateKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.private_numbers() == orig_key.private_numbers() assert ( parsed_key.private_numbers().public_numbers @@ -237,9 +255,9 @@ class TestAlgorithms: algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path("testkey_ec.pub")) as ec_key: - orig_key = algo.prepare_key(ec_key.read()) + orig_key = cast(EllipticCurvePublicKey, algo.prepare_key(ec_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(EllipticCurvePublicKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.public_numbers() == orig_key.public_numbers() @crypto_required @@ -284,7 +302,7 @@ class TestAlgorithms: algo = ECAlgorithm(ECAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] @crypto_required def test_ec_to_jwk_with_valid_curves(self): @@ -320,10 +338,10 @@ class TestAlgorithms: algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("jwk_rsa_pub.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_rsa_key.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(RSAPrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -333,9 +351,9 @@ class TestAlgorithms: algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.priv")) as rsa_key: - orig_key = algo.prepare_key(rsa_key.read()) + orig_key = cast(RSAPrivateKey, algo.prepare_key(rsa_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(RSAPrivateKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.private_numbers() == orig_key.private_numbers() assert ( parsed_key.private_numbers().public_numbers @@ -347,9 +365,9 @@ class TestAlgorithms: algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path("testkey_rsa.pub")) as rsa_key: - orig_key = algo.prepare_key(rsa_key.read()) + orig_key = cast(RSAPublicKey, algo.prepare_key(rsa_key.read())) - parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + parsed_key = cast(RSAPublicKey, algo.from_jwk(algo.to_jwk(orig_key))) assert parsed_key.public_numbers() == orig_key.public_numbers() @crypto_required @@ -380,14 +398,16 @@ class TestAlgorithms: with open(key_path("jwk_rsa_key.json")) as keyfile: keybytes = keyfile.read() - control_key = algo.from_jwk(keybytes).private_numbers() + control_key = cast(RSAPrivateKey, algo.from_jwk(keybytes)).private_numbers() keydata = json.loads(keybytes) delete_these = ["p", "q", "dp", "dq", "qi"] for field in delete_these: del keydata[field] - parsed_key = algo.from_jwk(json.dumps(keydata)).private_numbers() + parsed_key = cast( + RSAPrivateKey, algo.from_jwk(json.dumps(keydata)) + ).private_numbers() assert control_key.d == parsed_key.d assert control_key.p == parsed_key.p @@ -505,7 +525,7 @@ class TestAlgorithms: algo = RSAAlgorithm(RSAAlgorithm.SHA256) with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] @crypto_required def test_rsa_from_jwk_raises_exception_on_invalid_key(self): @@ -520,7 +540,7 @@ class TestAlgorithms: algo = ECAlgorithm(ECAlgorithm.SHA256) with pytest.raises(TypeError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] @crypto_required def test_ec_should_accept_pem_private_key_bytes(self): @@ -590,11 +610,11 @@ class TestAlgorithms: message = b"Hello World!" with open(key_path("testkey_rsa.priv")) as keyfile: - priv_key = algo.prepare_key(keyfile.read()) + priv_key = cast(RSAPrivateKey, algo.prepare_key(keyfile.read())) sig = algo.sign(message, priv_key) with open(key_path("testkey_rsa.pub")) as keyfile: - pub_key = algo.prepare_key(keyfile.read()) + pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(message, pub_key, sig) assert result @@ -617,7 +637,7 @@ class TestAlgorithms: jwt_sig += b"123" # Signature is now invalid with open(key_path("testkey_rsa.pub")) as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + jwt_pub_key = cast(RSAPublicKey, algo.prepare_key(keyfile.read())) result = algo.verify(jwt_message, jwt_pub_key, jwt_sig) assert not result @@ -678,7 +698,7 @@ class TestAlgorithmsRFC7520: ) algo = RSAAlgorithm(RSAAlgorithm.SHA256) - key = algo.prepare_key(load_rsa_pub_key()) + key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key())) result = algo.verify(signing_input, key, signature) assert result @@ -709,7 +729,7 @@ class TestAlgorithmsRFC7520: ) algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384) - key = algo.prepare_key(load_rsa_pub_key()) + key = cast(RSAPublicKey, algo.prepare_key(load_rsa_pub_key())) result = algo.verify(signing_input, key, signature) assert result @@ -759,7 +779,7 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with pytest.raises(InvalidKeyError): - algo.prepare_key(None) + algo.prepare_key(None) # type: ignore[arg-type] with open(key_path("testkey_ed25519")) as keyfile: algo.prepare_key(keyfile.read()) @@ -767,12 +787,6 @@ class TestOKPAlgorithms: with open(key_path("testkey_ed25519.pub")) as keyfile: algo.prepare_key(keyfile.read()) - def test_okp_ed25519_should_accept_unicode_key(self): - algo = OKPAlgorithm() - - with open(key_path("testkey_ed25519")) as ec_key: - algo.prepare_key(ec_key.read()) - def test_okp_ed25519_sign_should_generate_correct_signature_value(self): algo = OKPAlgorithm() @@ -781,10 +795,10 @@ class TestOKPAlgorithms: expected_sig = base64.b64decode(self.hello_world_sig) with open(key_path("testkey_ed25519")) as keyfile: - jwt_key = algo.prepare_key(keyfile.read()) + jwt_key = cast(Ed25519PrivateKey, algo.prepare_key(keyfile.read())) with open(key_path("testkey_ed25519.pub")) as keyfile: - jwt_pub_key = algo.prepare_key(keyfile.read()) + jwt_pub_key = cast(Ed25519PublicKey, algo.prepare_key(keyfile.read())) algo.sign(jwt_message, jwt_key) result = algo.verify(jwt_message, jwt_pub_key, expected_sig) @@ -829,7 +843,7 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key.public_key(), signature) @@ -840,7 +854,7 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) @@ -849,10 +863,10 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -867,7 +881,7 @@ class TestOKPAlgorithms: # Invalid instance type with pytest.raises(InvalidKeyError): - algo.from_jwk(123) + algo.from_jwk(123) # type: ignore[arg-type] # Invalid JSON with pytest.raises(InvalidKeyError): @@ -913,15 +927,15 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed25519.json")) as keyfile: - priv_key_1 = algo.from_jwk(keyfile.read()) + priv_key_1 = cast(Ed25519PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile: - pub_key_1 = algo.from_jwk(keyfile.read()) + pub_key_1 = cast(Ed25519PublicKey, algo.from_jwk(keyfile.read())) pub = algo.to_jwk(pub_key_1) pub_key_2 = algo.from_jwk(pub) pri = algo.to_jwk(priv_key_1) - priv_key_2 = algo.from_jwk(pri) + priv_key_2 = cast(Ed25519PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) signature_2 = algo.sign(b"Hello World!", priv_key_2) @@ -932,13 +946,13 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with pytest.raises(InvalidKeyError): - algo.to_jwk({"not": "a valid key"}) + algo.to_jwk({"not": "a valid key"}) # type: ignore[arg-type] def test_okp_ed448_jwk_private_key_should_parse_and_verify(self): algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key.public_key(), signature) @@ -949,7 +963,7 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - key = algo.from_jwk(keyfile.read()) + key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", key) assert algo.verify(b"Hello World!", key, signature) @@ -958,10 +972,10 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - priv_key = algo.from_jwk(keyfile.read()) + priv_key = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile: - pub_key = algo.from_jwk(keyfile.read()) + pub_key = cast(Ed448PublicKey, algo.from_jwk(keyfile.read())) signature = algo.sign(b"Hello World!", priv_key) assert algo.verify(b"Hello World!", pub_key, signature) @@ -976,7 +990,7 @@ class TestOKPAlgorithms: # Invalid instance type with pytest.raises(InvalidKeyError): - algo.from_jwk(123) + algo.from_jwk(123) # type: ignore[arg-type] # Invalid JSON with pytest.raises(InvalidKeyError): @@ -1022,15 +1036,15 @@ class TestOKPAlgorithms: algo = OKPAlgorithm() with open(key_path("jwk_okp_key_Ed448.json")) as keyfile: - priv_key_1 = algo.from_jwk(keyfile.read()) + priv_key_1 = cast(Ed448PrivateKey, algo.from_jwk(keyfile.read())) with open(key_path("jwk_okp_pub_Ed448.json")) as keyfile: - pub_key_1 = algo.from_jwk(keyfile.read()) + pub_key_1 = cast(Ed448PublicKey, algo.from_jwk(keyfile.read())) pub = algo.to_jwk(pub_key_1) pub_key_2 = algo.from_jwk(pub) pri = algo.to_jwk(priv_key_1) - priv_key_2 = algo.from_jwk(pri) + priv_key_2 = cast(Ed448PrivateKey, algo.from_jwk(pri)) signature_1 = algo.sign(b"Hello World!", priv_key_1) signature_2 = algo.sign(b"Hello World!", priv_key_2) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index d2aa915..3385716 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -82,7 +82,7 @@ class TestJWS: assert jws.options["verify_signature"] - def test_options_must_be_dict(self, jws): + def test_options_must_be_dict(self): pytest.raises(TypeError, PyJWS, options=object()) pytest.raises((TypeError, ValueError), PyJWS, options=("something")) @@ -533,11 +533,11 @@ class TestJWS: # string-formatted key with open(key_path("testkey_rsa.priv")) as rsa_priv_file: - priv_rsakey = rsa_priv_file.read() + priv_rsakey = rsa_priv_file.read() # type: ignore[assignment] jws_message = jws.encode(payload, priv_rsakey, algorithm=algo) with open(key_path("testkey_rsa.pub")) as rsa_pub_file: - pub_rsakey = rsa_pub_file.read() + pub_rsakey = rsa_pub_file.read() # type: ignore[assignment] jws.decode(jws_message, pub_rsakey, algorithms=[algo]) def test_rsa_related_algorithms(self, jws): @@ -582,11 +582,11 @@ class TestJWS: # string-formatted key with open(key_path("testkey_ec.priv")) as ec_priv_file: - priv_eckey = ec_priv_file.read() + priv_eckey = ec_priv_file.read() # type: ignore[assignment] jws_message = jws.encode(payload, priv_eckey, algorithm=algo) with open(key_path("testkey_ec.pub")) as ec_pub_file: - pub_eckey = ec_pub_file.read() + pub_eckey = ec_pub_file.read() # type: ignore[assignment] jws.decode(jws_message, pub_eckey, algorithms=[algo]) def test_ecdsa_related_algorithms(self, jws): diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 3c08ea5..0d53444 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -365,8 +365,8 @@ class TestJWT: secret = "secret" jwt_message = jwt.encode(payload, secret) - # With 3 seconds leeway, should be ok - for leeway in (3, timedelta(seconds=3)): + # With 5 seconds leeway, should be ok + for leeway in (5, timedelta(seconds=5)): decoded = jwt.decode( jwt_message, secret, leeway=leeway, algorithms=["HS256"] ) diff --git a/tests/test_utils.py b/tests/test_utils.py index a089f86..122dcb4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,4 +36,4 @@ def test_from_base64url_uint(inputval, expected): def test_force_bytes_raises_error_on_invalid_object(): with pytest.raises(TypeError): - force_bytes({}) + force_bytes({}) # type: ignore[arg-type] -- cgit v1.2.1