summaryrefslogtreecommitdiff
path: root/jwt/api_jwt.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/api_jwt.py')
-rw-r--r--jwt/api_jwt.py45
1 files changed, 24 insertions, 21 deletions
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 06c89f4..da273c5 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type, Union
-from .api_jws import PyJWS
+from . import api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
@@ -16,8 +16,11 @@ from .exceptions import (
)
-class PyJWT(PyJWS):
- header_type = "JWT"
+class PyJWT:
+ def __init__(self, options=None):
+ if options is None:
+ options = {}
+ self.options = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -33,7 +36,7 @@ class PyJWT(PyJWS):
def encode(
self,
- payload: Union[Dict, bytes],
+ payload: Dict[str, Any],
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
@@ -59,20 +62,18 @@ class PyJWT(PyJWS):
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")
- return super().encode(
+ return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)
- def decode(
+ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
- complete: bool = False,
**kwargs,
) -> Dict[str, Any]:
-
if options is None:
options = {"verify_signature": True}
else:
@@ -83,20 +84,16 @@ class PyJWT(PyJWS):
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
- decoded = super().decode(
+ decoded = api_jws.decode_complete(
jwt,
key=key,
algorithms=algorithms,
options=options,
- complete=complete,
**kwargs,
)
try:
- if complete:
- payload = json.loads(decoded["payload"])
- else:
- payload = json.loads(decoded)
+ payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError("Invalid payload string: %s" % e)
if not isinstance(payload, dict):
@@ -106,11 +103,19 @@ class PyJWT(PyJWS):
merged_options = {**self.options, **options}
self._validate_claims(payload, merged_options, **kwargs)
- if complete:
- decoded["payload"] = payload
- return decoded
+ decoded["payload"] = payload
+ return decoded
- return payload
+ def decode(
+ self,
+ jwt: str,
+ key: str = "",
+ algorithms: List[str] = None,
+ options: Dict = None,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded["payload"]
def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
@@ -215,7 +220,5 @@ class PyJWT(PyJWS):
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
+decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode
-register_algorithm = _jwt_global_obj.register_algorithm
-unregister_algorithm = _jwt_global_obj.unregister_algorithm
-get_unverified_header = _jwt_global_obj.get_unverified_header