from __future__ import annotations

import binascii
import json
import warnings
from typing import Any, List, Optional, Type

from .algorithms import (
    Algorithm,
    get_default_algorithms,
    has_crypto,
    requires_cryptography,
)
from .exceptions import (
    DecodeError,
    InvalidAlgorithmError,
    InvalidSignatureError,
    InvalidTokenError,
)
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning


class PyJWS:
    """
    Encoder / decoder for compact-serialized signed tokens.

    An instance keeps its own registry of signing algorithms (name ->
    ``Algorithm`` object) and a dict of default decode options.
    """

    # Default value of the "typ" header written by ``encode``.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Optional[List[str]] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        :param algorithms: names of the algorithms this instance may use;
            ``None`` keeps everything returned by ``get_default_algorithms()``.
        :param options: overrides merged on top of the default options.
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Remove algorithms that aren't on the allow-list.
        # Iterate over a snapshot of the keys because the dict is mutated.
        for key in list(self._algorithms.keys()):
            if key not in self._valid_algs:
                del self._algorithms[key]

        if options is None:
            options = {}
        self.options = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return the default decode options."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens.

        :raises KeyError: if the algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the 'alg' parameter.
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        :raises NotImplementedError: if no algorithm is registered under
            ``alg_name``; the message hints at a missing ``cryptography``
            install when that is the likely cause.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. "
                    "Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: str,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: Type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """
        Sign ``payload`` and return the compact serialization as ``str``.

        :param payload: raw bytes to sign.
        :param key: key handed to the algorithm's ``prepare_key``.
        :param algorithm: algorithm name; ``None`` selects ``"none"``.
            A truthy ``headers["alg"]`` takes precedence over this argument.
        :param headers: extra header fields merged over the defaults.
            ``"b64": False`` forces a detached payload.
        :param json_encoder: ``json.JSONEncoder`` subclass used for the header.
        :param is_payload_detached: when true the payload is signed unencoded
            and left out of the returned token.
        :param sort_headers: passed to ``json.dumps`` as ``sort_keys``.
        :raises NotImplementedError: if the algorithm is not registered.
        :raises InvalidTokenError: if ``headers["kid"]`` is not a string.
        """
        segments = []

        # declare a new var to narrow the type for type checkers
        algorithm_: str = algorithm if algorithm is not None else "none"

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers["alg"]

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)

        # A falsy "typ" (e.g. None or "") means: do not emit the field at all.
        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # A detached payload is signed as-is, without base64url encoding.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Parse a token and, unless disabled, verify its signature.

        :param jwt: the compact-serialized token.
        :param key: key handed to the algorithm's ``prepare_key``.
        :param algorithms: acceptable "alg" values; required (non-empty)
            whenever the signature is verified.
        :param options: per-call overrides of ``self.options``
            (e.g. ``{"verify_signature": False}``).
        :param detached_payload: payload bytes for tokens whose header has
            ``"b64": false``.
        :param kwargs: deprecated; any value triggers
            ``RemovedInPyjwt3Warning`` and is otherwise ignored.
        :returns: dict with the keys ``"payload"``, ``"header"`` and
            ``"signature"``.
        :raises DecodeError: malformed token, missing ``algorithms`` or
            missing ``detached_payload``.
        :raises InvalidAlgorithmError: "alg" missing, not allowed or unsupported.
        :raises InvalidSignatureError: signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller, not to this module.
                stacklevel=2,
            )
        if options is None:
            options = {}
        merged_options = {**self.options, **options}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms:
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" '
                "argument when calling decode()."
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    "It is required that you pass in a value for the "
                    '"detached_payload" argument to decode a message having '
                    "the b64 header set to false."
                )
            payload = detached_payload
            # Rebuild the signing input as <header segment>.<raw payload>.
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: str | bytes = "",
        algorithms: list[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs,
    ) -> Any:
        """
        Same as ``decode_complete`` but return only the ``"payload"`` entry.

        Extra ``kwargs`` are deprecated: they trigger
        ``RemovedInPyjwt3Warning`` and are not forwarded.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                # Attribute the warning to the caller, not to this module.
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """
        Split a compact token into its parts without verifying anything.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``signing_input`` is everything before the last ``"."`` of the
            token, still encoded.
        :raises DecodeError: wrong input type, too few segments, bad
            base64url data, or a header that is not a JSON object.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            # Split from the right first so signing_input keeps its own dot.
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: str | bytes = "",
        algorithms: list[str] | None = None,
    ) -> None:
        """
        Check ``signature`` over ``signing_input`` with the header's "alg".

        :raises InvalidAlgorithmError: "alg" is missing, falsy, not listed
            in ``algorithms`` (when given), or not registered.
        :raises InvalidSignatureError: the algorithm rejects the signature.
        """
        try:
            alg = header["alg"]
        except KeyError:
            # The KeyError adds nothing for the caller; hide it from the chain.
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        try:
            alg_obj = self.get_algorithm_by_name(alg)
        except NotImplementedError as e:
            raise InvalidAlgorithmError("Algorithm not supported") from e
        key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate the header fields this class checks (currently "kid")."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")


# Module-level convenience API bound to one shared PyJWS instance.
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header