From 2306f68b87c5f7eecbd838db900a2d3685b81b3c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 10 Dec 2022 17:44:04 +0200 Subject: Make mypy configuration stricter and improve typing (#830) * PyJWS._verify_signature: raise early KeyError if header is missing alg * Make Mypy configuration stricter * Improve typing in jwt.utils * Improve typing in jwt.help * Improve typing in jwt.exceptions * Improve typing in jwt.api_jwk * Improve typing in jwt.api_jws * Improve typing & clean up imports in jwt.algorithms * Correct JWS.decode rettype to any (payload could be something else) * Update typing in api_jwt * Improve typing in jwks_client * Improve typing in docs/conf.py * Fix (benign) mistyping in test_advisory * Fix misc type complaints in tests --- docs/conf.py | 4 +- jwt/algorithms.py | 99 ++++++++++++++++++++++++++++-------------------- jwt/api_jwk.py | 26 +++++++------ jwt/api_jws.py | 34 ++++++++++------- jwt/api_jwt.py | 58 +++++++++++++++++++++------- jwt/exceptions.py | 4 +- jwt/help.py | 8 ++-- jwt/jwks_client.py | 9 ++++- jwt/types.py | 3 ++ jwt/utils.py | 30 +++++++-------- pyproject.toml | 8 ++++ tests/test_advisory.py | 12 ++++-- tests/test_algorithms.py | 6 +-- tests/test_api_jwk.py | 2 +- 14 files changed, 187 insertions(+), 116 deletions(-) create mode 100644 jwt/types.py diff --git a/docs/conf.py b/docs/conf.py index 44b5103..938631f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,7 +4,7 @@ import re import sphinx_rtd_theme -def read(*parts): +def read(*parts) -> str: """ Build an absolute path from *parts* and and return the contents of the resulting file. Assume UTF-8 encoding. @@ -14,7 +14,7 @@ def read(*parts): return f.read() -def find_version(*file_paths): +def find_version(*file_paths) -> str: """ Build a path from *file_paths* and search for a ``__version__`` string inside. diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 4fae441..c809f15 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,8 +1,10 @@ import hashlib import hmac import json +from typing import Any, Dict, Union from .exceptions import InvalidKeyError +from .types import JWKDict from .utils import ( base64url_decode, base64url_encode, @@ -20,10 +22,18 @@ try: from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.ec import ( + ECDSA, + SECP256K1, + SECP256R1, + SECP384R1, + SECP521R1, + EllipticCurve, EllipticCurvePrivateKey, + EllipticCurvePrivateNumbers, EllipticCurvePublicKey, + EllipticCurvePublicNumbers, ) from cryptography.hazmat.primitives.asymmetric.ed448 import ( Ed448PrivateKey, @@ -73,7 +83,7 @@ requires_cryptography = { } -def get_default_algorithms(): +def get_default_algorithms() -> Dict[str, "Algorithm"]: """ Returns the algorithms that are implemented by the library. """ @@ -130,25 +140,29 @@ class Algorithm: ): digest = hashes.Hash(hash_alg(), backend=default_backend()) digest.update(bytestr) - return digest.finalize() + return bytes(digest.finalize()) else: - return hash_alg(bytestr).digest() + return bytes(hash_alg(bytestr).digest()) - def prepare_key(self, key): + # TODO: all key-related `Any`s in this class should optimally be made + # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605 + # that may still be poorly supported. + + def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ raise NotImplementedError - def sign(self, msg, key): + def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ raise NotImplementedError - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. @@ -156,14 +170,14 @@ class Algorithm: raise NotImplementedError @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj) -> JWKDict: """ Serializes a given RSA key into a JWK """ raise NotImplementedError @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: JWKDict): """ Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object """ @@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm): SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -242,7 +256,7 @@ class HMACAlgorithm(Algorithm): return base64url_decode(obj["k"]) - def sign(self, msg, key): + def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() def verify(self, msg, key, sig): @@ -261,7 +275,7 @@ if has_crypto: SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -271,16 +285,15 @@ if has_crypto: if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) try: - if key.startswith(b"ssh-rsa"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ssh-rsa"): + return load_ssh_public_key(key_bytes) else: - key = load_pem_private_key(key, password=None) + return load_pem_private_key(key_bytes, password=None) except ValueError: - key = load_pem_public_key(key) - return key + return load_pem_public_key(key_bytes) @staticmethod def to_jwk(key_obj): @@ -383,12 +396,10 @@ if has_crypto: return numbers.private_key() elif "n" in obj and "e" in obj: # Public key - numbers = RSAPublicNumbers( + return RSAPublicNumbers( from_base64url_uint(obj["e"]), from_base64url_uint(obj["n"]), - ) - - return numbers.public_key() + ).public_key() else: raise InvalidKeyError("Not a public or private key") @@ -412,7 +423,7 @@ if has_crypto: SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg) -> None: self.hash_alg = hash_alg def prepare_key(self, key): @@ -422,18 +433,18 @@ if has_crypto: if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) # Attempt to load key. We don't know if it's # a Signing Key or a Verifying Key, so we try # the Verifying Key first. try: - if key.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ecdsa-sha2-"): + key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key) + key = load_pem_public_key(key_bytes) except ValueError: - key = load_pem_private_key(key, password=None) + key = load_pem_private_key(key_bytes, password=None) # Explicit check the key to prevent confusing errors from cryptography if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): @@ -444,7 +455,7 @@ if has_crypto: return key def sign(self, msg, key): - der_sig = key.sign(msg, ec.ECDSA(self.hash_alg())) + der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) @@ -457,7 +468,7 @@ if has_crypto: try: if isinstance(key, EllipticCurvePrivateKey): key = key.public_key() - key.verify(der_sig, msg, ec.ECDSA(self.hash_alg())) + key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False @@ -472,13 +483,13 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - if isinstance(key_obj.curve, ec.SECP256R1): + if isinstance(key_obj.curve, SECP256R1): crv = "P-256" - elif isinstance(key_obj.curve, ec.SECP384R1): + elif isinstance(key_obj.curve, SECP384R1): crv = "P-384" - elif isinstance(key_obj.curve, ec.SECP521R1): + elif isinstance(key_obj.curve, SECP521R1): crv = "P-521" - elif isinstance(key_obj.curve, ec.SECP256K1): + elif isinstance(key_obj.curve, SECP256K1): crv = "secp256k1" else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") @@ -498,7 +509,9 @@ if has_crypto: return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk( + jwk: Any, + ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -519,24 +532,26 @@ if has_crypto: y = base64url_decode(obj.get("y")) curve = obj.get("crv") + curve_obj: EllipticCurve + if curve == "P-256": if len(x) == len(y) == 32: - curve_obj = ec.SECP256R1() + curve_obj = SECP256R1() else: raise InvalidKeyError("Coords should be 32 bytes for curve P-256") elif curve == "P-384": if len(x) == len(y) == 48: - curve_obj = ec.SECP384R1() + curve_obj = SECP384R1() else: raise InvalidKeyError("Coords should be 48 bytes for curve P-384") elif curve == "P-521": if len(x) == len(y) == 66: - curve_obj = ec.SECP521R1() + curve_obj = SECP521R1() else: raise InvalidKeyError("Coords should be 66 bytes for curve P-521") elif curve == "secp256k1": if len(x) == len(y) == 32: - curve_obj = ec.SECP256K1() + curve_obj = SECP256K1() else: raise InvalidKeyError( "Coords should be 32 bytes for curve secp256k1" @@ -544,7 +559,7 @@ if has_crypto: else: raise InvalidKeyError(f"Invalid curve: {curve}") - public_numbers = ec.EllipticCurvePublicNumbers( + public_numbers = EllipticCurvePublicNumbers( x=int.from_bytes(x, byteorder="big"), y=int.from_bytes(y, byteorder="big"), curve=curve_obj, @@ -559,7 +574,7 @@ if has_crypto: "D should be {} bytes for curve {}", len(x), curve ) - return ec.EllipticCurvePrivateNumbers( + return EllipticCurvePrivateNumbers( int.from_bytes(d, byteorder="big"), public_numbers ).private_key() @@ -600,7 +615,7 @@ if has_crypto: This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: pass def prepare_key(self, key): diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index aa3dd32..ef8cce8 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -2,13 +2,15 @@ from __future__ import annotations import json import time +from typing import Any, Optional from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError +from .types import JWKDict class PyJWK: - def __init__(self, jwk_data, algorithm=None): + def __init__(self, jwk_data: JWKDict, algorithm: Optional[str] = None) -> None: self._algorithms = get_default_algorithms() self._jwk_data = jwk_data @@ -55,29 +57,29 @@ class PyJWK: self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj, algorithm=None): + def from_dict(obj: JWKDict, algorithm: Optional[str] = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod - def from_json(data, algorithm=None): + def from_json(data: str, algorithm: None = None) -> "PyJWK": obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @property - def key_type(self): + def key_type(self) -> str: return self._jwk_data.get("kty", None) @property - def key_id(self): + def key_id(self) -> str: return self._jwk_data.get("kid", None) @property - def public_key_use(self): + def public_key_use(self) -> Optional[str]: return self._jwk_data.get("use", None) class PyJWKSet: - def __init__(self, keys: list[dict]) -> None: + def __init__(self, keys: list[JWKDict]) -> None: self.keys = [] if not keys: @@ -97,16 +99,16 @@ class PyJWKSet: raise PyJWKSetError("The JWK Set did not contain any usable keys") @staticmethod - def from_dict(obj): + def from_dict(obj: dict[str, Any]) -> "PyJWKSet": keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data): + def from_json(data: str) -> "PyJWKSet": obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid): + def __getitem__(self, kid: str) -> "PyJWK": for key in self.keys: if key.key_id == kid: return key @@ -118,8 +120,8 @@ class PyJWTSetWithTimestamp: self.jwk_set = jwk_set self.timestamp = time.monotonic() - def get_jwk_set(self): + def get_jwk_set(self) -> PyJWKSet: return self.jwk_set - def get_timestamp(self): + def get_timestamp(self) -> float: return self.timestamp diff --git a/jwt/api_jws.py b/jwt/api_jws.py index c914db1..0a775d7 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -3,7 +3,7 @@ from __future__ import annotations import binascii import json import warnings -from typing import Any, Type +from typing import Any, List, Optional, Type from .algorithms import ( Algorithm, @@ -24,7 +24,11 @@ from .warnings import RemovedInPyjwt3Warning class PyJWS: header_typ = "JWT" - def __init__(self, algorithms=None, options=None) -> None: + def __init__( + self, + algorithms: Optional[List[str]] = None, + options: Optional[dict[str, Any]] = None, + ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( set(algorithms) if algorithms is not None else set(self._algorithms) @@ -164,8 +168,8 @@ class PyJWS: def decode_complete( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -209,13 +213,13 @@ class PyJWS: def decode( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, - ) -> str: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode() is deprecated " @@ -228,7 +232,7 @@ class PyJWS: ) return decoded["payload"] - def get_unverified_header(self, jwt: str | bytes) -> dict: + def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]: """Returns back the JWT header parameters as a dict() Note: The signature is not verified so the header parameters @@ -239,7 +243,7 @@ class PyJWS: return headers - def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]: if isinstance(jwt, str): jwt = jwt.encode("utf-8") @@ -280,13 +284,15 @@ class PyJWS: def _verify_signature( self, signing_input: bytes, - header: dict, + header: dict[str, Any], signature: bytes, - key: str = "", + key: str | bytes = "", algorithms: list[str] | None = None, ) -> None: - - alg = header.get("alg") + try: + alg = header["alg"] + except KeyError: + raise InvalidAlgorithmError("Algorithm not specified") if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") @@ -304,7 +310,7 @@ class PyJWS: if "kid" in headers: self._validate_kid(headers["kid"]) - def _validate_kid(self, kid: str) -> None: + def _validate_kid(self, kid: Any) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string") diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 5d21afc..8840708 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -21,10 +21,10 @@ from .warnings import RemovedInPyjwt3Warning class PyJWT: - def __init__(self, options=None): + def __init__(self, options: Optional[dict[str, Any]] = None) -> None: if options is None: options = {} - self.options = {**self._get_default_options(), **options} + self.options: dict[str, Any] = {**self._get_default_options(), **options} @staticmethod def _get_default_options() -> Dict[str, Union[bool, List[str]]]: @@ -96,8 +96,8 @@ class PyJWT: def decode_complete( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, # deprecated arg, remove in pyjwt3 @@ -181,8 +181,8 @@ class PyJWT: def decode( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: str | bytes = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, # deprecated arg, remove in pyjwt3 @@ -196,7 +196,7 @@ class PyJWT: leeway: Union[int, float, timedelta] = 0, # kwargs **kwargs, - ) -> Dict[str, Any]: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode() is deprecated " @@ -217,7 +217,14 @@ class PyJWT: ) return decoded["payload"] - def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): + def _validate_claims( + self, + payload: dict[str, Any], + options: dict[str, Any], + audience=None, + issuer=None, + leeway: float | timedelta = 0, + ) -> None: if isinstance(leeway, timedelta): leeway = leeway.total_seconds() @@ -243,12 +250,21 @@ class PyJWT: if options["verify_aud"]: self._validate_aud(payload, audience) - def _validate_required_claims(self, payload, options): + def _validate_required_claims( + self, + payload: dict[str, Any], + options: dict[str, Any], + ) -> None: for claim in options["require"]: if payload.get(claim) is None: raise MissingRequiredClaimError(claim) - def _validate_iat(self, payload, now, leeway): + def _validate_iat( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: iat = payload["iat"] try: int(iat) @@ -257,7 +273,12 @@ class PyJWT: if iat > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (iat)") - def _validate_nbf(self, payload, now, leeway): + def _validate_nbf( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: try: nbf = int(payload["nbf"]) except ValueError: @@ -266,7 +287,12 @@ class PyJWT: if nbf > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (nbf)") - def _validate_exp(self, payload, now, leeway): + def _validate_exp( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: try: exp = int(payload["exp"]) except ValueError: @@ -275,7 +301,11 @@ class PyJWT: if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") - def _validate_aud(self, payload, audience): + def _validate_aud( + self, + payload: dict[str, Any], + audience: Optional[Union[str, Iterable[str]]], + ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: return @@ -303,7 +333,7 @@ class PyJWT: if all(aud not in audience_claims for aud in audience): raise InvalidAudienceError("Invalid audience") - def _validate_iss(self, payload, issuer): + def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None: if issuer is None: return diff --git a/jwt/exceptions.py b/jwt/exceptions.py index ee201ad..dbe0056 100644 --- a/jwt/exceptions.py +++ b/jwt/exceptions.py @@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError): class MissingRequiredClaimError(InvalidTokenError): - def __init__(self, claim): + def __init__(self, claim: str) -> None: self.claim = claim - def __str__(self): + def __str__(self) -> str: return f'Token is missing the "{self.claim}" claim' diff --git a/jwt/help.py b/jwt/help.py index 0c02eb9..80b0ca5 100644 --- a/jwt/help.py +++ b/jwt/help.py @@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version try: import cryptography + + cryptography_version = cryptography.__version__ except ModuleNotFoundError: - cryptography = None + cryptography_version = "" def info() -> Dict[str, Dict[str, str]]: @@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]: if implementation == "CPython": implementation_version = platform.python_version() elif implementation == "PyPy": - pypy_version_info = getattr(sys, "pypy_version_info") + pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined] implementation_version = ( f"{pypy_version_info.major}." f"{pypy_version_info.minor}." @@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]: "name": implementation, "version": implementation_version, }, - "cryptography": {"version": getattr(cryptography, "__version__", "")}, + "cryptography": {"version": cryptography_version}, "pyjwt": {"version": pyjwt_version}, } diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index daeb830..aa33bb3 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,7 +1,7 @@ import json import urllib.request from functools import lru_cache -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet @@ -18,8 +18,10 @@ class PyJWKClient: max_cached_keys: int = 16, cache_jwk_set: bool = True, lifespan: int = 300, - headers: dict = {}, + headers: Optional[Dict[str, Any]] = None, ): + if headers is None: + headers = {} self.uri = uri self.jwk_set_cache: Optional[JWKSetCache] = None self.headers = headers @@ -62,6 +64,9 @@ class PyJWKClient: if data is None: data = self.fetch_data() + if not isinstance(data, dict): + raise PyJWKClientError("The JWKS endpoint did not return a JSON object") + return PyJWKSet.from_dict(data) def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: diff --git a/jwt/types.py b/jwt/types.py new file mode 100644 index 0000000..830b185 --- /dev/null +++ b/jwt/types.py @@ -0,0 +1,3 @@ +from typing import Any, Dict + +JWKDict = Dict[str, Any] diff --git a/jwt/utils.py b/jwt/utils.py index 16cae06..92e3c8b 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Union +from typing import Any, AnyStr try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -10,10 +10,10 @@ try: encode_dss_signature, ) except ModuleNotFoundError: - EllipticCurve = None + pass -def force_bytes(value: Union[str, bytes]) -> bytes: +def force_bytes(value: Any) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): @@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes: raise TypeError("Expected a string value") -def base64url_decode(input: Union[str, bytes]) -> bytes: - if isinstance(input, str): - input = input.encode("ascii") +def base64url_decode(input: AnyStr) -> bytes: + input_bytes = force_bytes(input) - rem = len(input) % 4 + rem = len(input_bytes) % 4 if rem > 0: - input += b"=" * (4 - rem) + input_bytes += b"=" * (4 - rem) - return base64.urlsafe_b64decode(input) + return base64.urlsafe_b64decode(input_bytes) def base64url_encode(input: bytes) -> bytes: @@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes: return base64url_encode(int_bytes) -def from_base64url_uint(val: Union[str, bytes]) -> int: - if isinstance(val, str): - val = val.encode("ascii") - - data = base64url_decode(val) +def from_base64url_uint(val: AnyStr) -> int: + data = base64url_decode(force_bytes(val)) return int.from_bytes(data, byteorder="big") @@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes: return val.to_bytes(byte_length, "big", signed=False) -def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: +def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes) -def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: +def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: r = bytes_to_number(raw_sig[:num_bytes]) s = bytes_to_number(raw_sig[num_bytes:]) - return encode_dss_signature(r, s) + return bytes(encode_dss_signature(r, s)) # Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 diff --git a/pyproject.toml b/pyproject.toml index f406505..e7920ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,3 +23,11 @@ python_version = 3.7 ignore_missing_imports = true warn_unused_ignores = true no_implicit_optional = true +strict = true +# TODO: remove these strict loosenings when possible +allow_incomplete_defs = true +allow_untyped_defs = true + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_calls = false diff --git a/tests/test_advisory.py b/tests/test_advisory.py index ed768d4..8a7e05e 100644 --- a/tests/test_advisory.py +++ b/tests/test_advisory.py @@ -1,6 +1,7 @@ import pytest import jwt +from jwt.algorithms import get_default_algorithms from jwt.exceptions import InvalidKeyError from .utils import crypto_required @@ -50,18 +51,20 @@ class TestAdvisory: # public key is a HMAC secret encoded_bad = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0ZXN0IjoxMjM0fQ.6ulDpqSlbHmQ8bZXhZRLFko9SwcHrghCwh8d-exJEE4" + algorithm_names = list(get_default_algorithms()) + # Both of the jwt tokens are validated as valid jwt.decode( encoded_good, pub_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms(), + algorithms=algorithm_names, ) with pytest.raises(InvalidKeyError): jwt.decode( encoded_bad, pub_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms(), + algorithms=algorithm_names, ) # Of course the receiver should specify ed25519 algorithm to be used if @@ -100,16 +103,17 @@ class TestAdvisory: # encoded_bad = jwt.encode({"test": 1234}, ssh_key_bytes, algorithm="HS256") encoded_bad = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0ZXN0IjoxMjM0fQ.5eYfbrbeGYmWfypQ6rMWXNZ8bdHcqKng5GPr9MJZITU" + algorithm_names = list(get_default_algorithms()) # Both of the jwt tokens are validated as valid jwt.decode( encoded_good, ssh_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms(), + algorithms=algorithm_names, ) with pytest.raises(InvalidKeyError): jwt.decode( encoded_bad, ssh_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms(), + algorithms=algorithm_names, ) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 894ce28..3ace425 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -25,19 +25,19 @@ class TestAlgorithms: algo = Algorithm() with pytest.raises(NotImplementedError): - algo.sign("message", "key") + algo.sign(b"message", "key") def test_algorithm_should_throw_exception_if_verify_not_impl(self): algo = Algorithm() with pytest.raises(NotImplementedError): - algo.verify("message", "key", "signature") + algo.verify(b"message", "key", b"signature") def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self): algo = Algorithm() with pytest.raises(NotImplementedError): - algo.from_jwk("value") + algo.from_jwk({"val": "ue"}) def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self): algo = Algorithm() diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py index 040e81b..3f40fc5 100644 --- a/tests/test_api_jwk.py +++ b/tests/test_api_jwk.py @@ -293,7 +293,7 @@ class TestPyJWKSet: def test_invalid_keys_list(self): with pytest.raises(PyJWKSetError) as err: - PyJWKSet(keys="string") + PyJWKSet(keys="string") # type: ignore assert str(err.value) == "Invalid JWK Set value" def test_empty_keys_list(self): -- cgit v1.2.1