diff options
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r-- | jwt/api_jwt.py | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 06c89f4..da273c5 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Type, Union -from .api_jws import PyJWS +from . import api_jws from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -16,8 +16,11 @@ from .exceptions import ( ) -class PyJWT(PyJWS): - header_type = "JWT" +class PyJWT: + def __init__(self, options=None): + if options is None: + options = {} + self.options = {**self._get_default_options(), **options} @staticmethod def _get_default_options() -> Dict[str, Union[bool, List[str]]]: @@ -33,7 +36,7 @@ class PyJWT(PyJWS): def encode( self, - payload: Union[Dict, bytes], + payload: Dict[str, Any], key: str, algorithm: str = "HS256", headers: Optional[Dict] = None, @@ -59,20 +62,18 @@ class PyJWT(PyJWS): payload, separators=(",", ":"), cls=json_encoder ).encode("utf-8") - return super().encode( + return api_jws.encode( json_payload, key, algorithm, headers, json_encoder ) - def decode( + def decode_complete( self, jwt: str, key: str = "", algorithms: List[str] = None, options: Dict = None, - complete: bool = False, **kwargs, ) -> Dict[str, Any]: - if options is None: options = {"verify_signature": True} else: @@ -83,20 +84,16 @@ class PyJWT(PyJWS): 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) - decoded = super().decode( + decoded = api_jws.decode_complete( jwt, key=key, algorithms=algorithms, options=options, - complete=complete, **kwargs, ) try: - if complete: - payload = json.loads(decoded["payload"]) - else: - payload = json.loads(decoded) + payload = json.loads(decoded["payload"]) except ValueError as e: raise DecodeError("Invalid payload string: %s" % e) if not isinstance(payload, dict): @@ -106,11 +103,19 @@ class PyJWT(PyJWS): merged_options = {**self.options, **options} self._validate_claims(payload, merged_options, **kwargs) - if complete: - decoded["payload"] = payload - return decoded + decoded["payload"] = payload + return decoded - return payload + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + **kwargs, + ) -> Dict[str, Any]: + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + return decoded["payload"] def _validate_claims( self, payload, options, audience=None, issuer=None, leeway=0, **kwargs @@ -215,7 +220,5 @@ class PyJWT(PyJWS): _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode +decode_complete = _jwt_global_obj.decode_complete decode = _jwt_global_obj.decode -register_algorithm = _jwt_global_obj.register_algorithm -unregister_algorithm = _jwt_global_obj.unregister_algorithm -get_unverified_header = _jwt_global_obj.get_unverified_header |