summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jwt/__init__.py8
-rw-r--r--jwt/api_jws.py32
-rw-r--r--jwt/api_jwt.py45
-rw-r--r--jwt/jwks_client.py6
-rw-r--r--jwt/utils.py2
-rw-r--r--tests/test_api_jws.py21
-rw-r--r--tests/test_api_jwt.py26
7 files changed, 94 insertions, 46 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py
index 9d95053..55a5ae3 100644
--- a/jwt/__init__.py
+++ b/jwt/__init__.py
@@ -1,12 +1,10 @@
-from .api_jws import PyJWS
-from .api_jwt import (
- PyJWT,
- decode,
- encode,
+from .api_jws import (
+ PyJWS,
get_unverified_header,
register_algorithm,
unregister_algorithm,
)
+from .api_jwt import PyJWT, decode, encode
from .exceptions import (
DecodeError,
ExpiredSignatureError,
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 17c49db..6ccce80 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,7 +1,7 @@
import binascii
import json
from collections.abc import Mapping
-from typing import Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type
from .algorithms import (
Algorithm,
@@ -77,7 +77,7 @@ class PyJWS:
def encode(
self,
- payload: Union[Dict, bytes],
+ payload: bytes,
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
@@ -127,15 +127,14 @@ class PyJWS:
return encoded_string.decode("utf-8")
- def decode(
+ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
- complete: bool = False,
**kwargs,
- ):
+ ) -> Dict[str, Any]:
if options is None:
options = {}
merged_options = {**self.options, **options}
@@ -153,14 +152,22 @@ class PyJWS:
payload, signing_input, header, signature, key, algorithms
)
- if complete:
- return {
- "payload": payload,
- "header": header,
- "signature": signature,
- }
+ return {
+ "payload": payload,
+ "header": header,
+ "signature": signature,
+ }
- return payload
+ def decode(
+ self,
+ jwt: str,
+ key: str = "",
+ algorithms: List[str] = None,
+ options: Dict = None,
+ **kwargs,
+ ) -> str:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded["payload"]
def get_unverified_header(self, jwt):
"""Returns back the JWT header parameters as a dict()
@@ -249,6 +256,7 @@ class PyJWS:
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
+decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 06c89f4..da273c5 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type, Union
-from .api_jws import PyJWS
+from . import api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
@@ -16,8 +16,11 @@ from .exceptions import (
)
-class PyJWT(PyJWS):
- header_type = "JWT"
+class PyJWT:
+ def __init__(self, options=None):
+ if options is None:
+ options = {}
+ self.options = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -33,7 +36,7 @@ class PyJWT(PyJWS):
def encode(
self,
- payload: Union[Dict, bytes],
+ payload: Dict[str, Any],
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
@@ -59,20 +62,18 @@ class PyJWT(PyJWS):
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")
- return super().encode(
+ return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)
- def decode(
+ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
- complete: bool = False,
**kwargs,
) -> Dict[str, Any]:
-
if options is None:
options = {"verify_signature": True}
else:
@@ -83,20 +84,16 @@ class PyJWT(PyJWS):
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
- decoded = super().decode(
+ decoded = api_jws.decode_complete(
jwt,
key=key,
algorithms=algorithms,
options=options,
- complete=complete,
**kwargs,
)
try:
- if complete:
- payload = json.loads(decoded["payload"])
- else:
- payload = json.loads(decoded)
+ payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError("Invalid payload string: %s" % e)
if not isinstance(payload, dict):
@@ -106,11 +103,19 @@ class PyJWT(PyJWS):
merged_options = {**self.options, **options}
self._validate_claims(payload, merged_options, **kwargs)
- if complete:
- decoded["payload"] = payload
- return decoded
+ decoded["payload"] = payload
+ return decoded
- return payload
+ def decode(
+ self,
+ jwt: str,
+ key: str = "",
+ algorithms: List[str] = None,
+ options: Dict = None,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded["payload"]
def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
@@ -215,7 +220,5 @@ class PyJWT(PyJWS):
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
+decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode
-register_algorithm = _jwt_global_obj.register_algorithm
-unregister_algorithm = _jwt_global_obj.unregister_algorithm
-get_unverified_header = _jwt_global_obj.get_unverified_header
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index 707724e..ecd10cb 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -2,7 +2,7 @@ import json
import urllib.request
from .api_jwk import PyJWKSet
-from .api_jwt import decode as decode_token
+from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
@@ -50,8 +50,6 @@ class PyJWKClient:
return signing_key
def get_signing_key_from_jwt(self, token):
- unverified = decode_token(
- token, complete=True, options={"verify_signature": False}
- )
+ unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))
diff --git a/jwt/utils.py b/jwt/utils.py
index 495c8c7..f9579d3 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -32,7 +32,7 @@ def base64url_decode(input):
return base64.urlsafe_b64decode(input)
-def base64url_encode(input):
+def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index a0456dc..2e8dd90 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -215,6 +215,27 @@ class TestJWS:
assert decoded_payload == payload
+ def test_decodes_complete_valid_jws(self, jws, payload):
+ example_secret = "secret"
+ example_jws = (
+ b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9."
+ b"aGVsbG8gd29ybGQ."
+ b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
+ )
+
+ decoded = jws.decode_complete(
+ example_jws, example_secret, algorithms=["HS256"]
+ )
+
+ assert decoded == {
+ "header": {"alg": "HS256", "typ": "JWT"},
+ "payload": payload,
+ "signature": (
+ b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4"
+ b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3"
+ ),
+ }
+
# 'Control' Elliptic Curve jws created by another library.
# Used to test for regressions that could affect both
# encoding / decoding operations equally (causing tests
diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py
index 35ba6ba..dda896f 100644
--- a/tests/test_api_jwt.py
+++ b/tests/test_api_jwt.py
@@ -47,6 +47,27 @@ class TestJWT:
assert decoded_payload == example_payload
+ def test_decodes_complete_valid_jwt(self, jwt):
+ example_payload = {"hello": "world"}
+ example_secret = "secret"
+ example_jwt = (
+ b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
+ b".eyJoZWxsbyI6ICJ3b3JsZCJ9"
+ b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8"
+ )
+ decoded = jwt.decode_complete(
+ example_jwt, example_secret, algorithms=["HS256"]
+ )
+
+ assert decoded == {
+ "header": {"alg": "HS256", "typ": "JWT"},
+ "payload": example_payload,
+ "signature": (
+ b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2'
+ b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?"
+ ),
+ }
+
def test_load_verify_valid_jwt(self, jwt):
example_payload = {"hello": "world"}
example_secret = "secret"
@@ -313,13 +334,12 @@ class TestJWT:
secret = "secret"
jwt_message = jwt.encode(payload, secret)
- decoded_payload, signing, header, signature = jwt._load(jwt_message)
-
# With 3 seconds leeway, should be ok
for leeway in (3, timedelta(seconds=3)):
- jwt.decode(
+ decoded = jwt.decode(
jwt_message, secret, leeway=leeway, algorithms=["HS256"]
)
+ assert decoded == payload
# With 1 seconds, should fail
for leeway in (1, timedelta(seconds=1)):