diff options
-rw-r--r-- | jwt/__init__.py | 8 | ||||
-rw-r--r-- | jwt/api_jws.py | 32 | ||||
-rw-r--r-- | jwt/api_jwt.py | 45 | ||||
-rw-r--r-- | jwt/jwks_client.py | 6 | ||||
-rw-r--r-- | jwt/utils.py | 2 | ||||
-rw-r--r-- | tests/test_api_jws.py | 21 | ||||
-rw-r--r-- | tests/test_api_jwt.py | 26 |
7 files changed, 94 insertions, 46 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 9d95053..55a5ae3 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,12 +1,10 @@ -from .api_jws import PyJWS -from .api_jwt import ( - PyJWT, - decode, - encode, +from .api_jws import ( + PyJWS, get_unverified_header, register_algorithm, unregister_algorithm, ) +from .api_jwt import PyJWT, decode, encode from .exceptions import ( DecodeError, ExpiredSignatureError, diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 17c49db..6ccce80 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,7 +1,7 @@ import binascii import json from collections.abc import Mapping -from typing import Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from .algorithms import ( Algorithm, @@ -77,7 +77,7 @@ class PyJWS: def encode( self, - payload: Union[Dict, bytes], + payload: bytes, key: str, algorithm: str = "HS256", headers: Optional[Dict] = None, @@ -127,15 +127,14 @@ class PyJWS: return encoded_string.decode("utf-8") - def decode( + def decode_complete( self, jwt: str, key: str = "", algorithms: List[str] = None, options: Dict = None, - complete: bool = False, **kwargs, - ): + ) -> Dict[str, Any]: if options is None: options = {} merged_options = {**self.options, **options} @@ -153,14 +152,22 @@ class PyJWS: payload, signing_input, header, signature, key, algorithms ) - if complete: - return { - "payload": payload, - "header": header, - "signature": signature, - } + return { + "payload": payload, + "header": header, + "signature": signature, + } - return payload + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + **kwargs, + ) -> str: + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + return decoded["payload"] def get_unverified_header(self, jwt): """Returns back the JWT header parameters as a dict() @@ -249,6 +256,7 @@ class PyJWS: _jws_global_obj = PyJWS() encode = _jws_global_obj.encode +decode_complete = _jws_global_obj.decode_complete decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 06c89f4..da273c5 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Type, Union -from .api_jws import PyJWS +from . import api_jws from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -16,8 +16,11 @@ from .exceptions import ( ) -class PyJWT(PyJWS): - header_type = "JWT" +class PyJWT: + def __init__(self, options=None): + if options is None: + options = {} + self.options = {**self._get_default_options(), **options} @staticmethod def _get_default_options() -> Dict[str, Union[bool, List[str]]]: @@ -33,7 +36,7 @@ class PyJWT(PyJWS): def encode( self, - payload: Union[Dict, bytes], + payload: Dict[str, Any], key: str, algorithm: str = "HS256", headers: Optional[Dict] = None, @@ -59,20 +62,18 @@ class PyJWT(PyJWS): payload, separators=(",", ":"), cls=json_encoder ).encode("utf-8") - return super().encode( + return api_jws.encode( json_payload, key, algorithm, headers, json_encoder ) - def decode( + def decode_complete( self, jwt: str, key: str = "", algorithms: List[str] = None, options: Dict = None, - complete: bool = False, **kwargs, ) -> Dict[str, Any]: - if options is None: options = {"verify_signature": True} else: @@ -83,20 +84,16 @@ class PyJWT(PyJWS): 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) - decoded = super().decode( + decoded = api_jws.decode_complete( jwt, key=key, algorithms=algorithms, options=options, - complete=complete, **kwargs, ) try: - if complete: - payload = json.loads(decoded["payload"]) - else: - payload = json.loads(decoded) + payload = json.loads(decoded["payload"]) except ValueError as e: raise DecodeError("Invalid payload string: %s" % e) if not isinstance(payload, dict): @@ -106,11 +103,19 @@ class PyJWT(PyJWS): merged_options = {**self.options, **options} self._validate_claims(payload, merged_options, **kwargs) - if complete: - decoded["payload"] = payload - return decoded + decoded["payload"] = payload + return decoded - return payload + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + **kwargs, + ) -> Dict[str, Any]: + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + return decoded["payload"] def _validate_claims( self, payload, options, audience=None, issuer=None, leeway=0, **kwargs @@ -215,7 +220,5 @@ class PyJWT(PyJWS): _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode +decode_complete = _jwt_global_obj.decode_complete decode = _jwt_global_obj.decode -register_algorithm = _jwt_global_obj.register_algorithm -unregister_algorithm = _jwt_global_obj.unregister_algorithm -get_unverified_header = _jwt_global_obj.get_unverified_header diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 707724e..ecd10cb 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -2,7 +2,7 @@ import json import urllib.request from .api_jwk import PyJWKSet -from .api_jwt import decode as decode_token +from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientError @@ -50,8 +50,6 @@ class PyJWKClient: return signing_key def get_signing_key_from_jwt(self, token): - unverified = decode_token( - token, complete=True, options={"verify_signature": False} - ) + unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) diff --git a/jwt/utils.py b/jwt/utils.py index 495c8c7..f9579d3 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -32,7 +32,7 @@ def base64url_decode(input): return base64.urlsafe_b64decode(input) -def base64url_encode(input): +def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index a0456dc..2e8dd90 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -215,6 +215,27 @@ class TestJWS: assert decoded_payload == payload + def test_decodes_complete_valid_jws(self, jws, payload): + example_secret = "secret" + example_jws = ( + b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + b"aGVsbG8gd29ybGQ." + b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM" + ) + + decoded = jws.decode_complete( + example_jws, example_secret, algorithms=["HS256"] + ) + + assert decoded == { + "header": {"alg": "HS256", "typ": "JWT"}, + "payload": payload, + "signature": ( + b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4" + b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3" + ), + } + # 'Control' Elliptic Curve jws created by another library. # Used to test for regressions that could affect both # encoding / decoding operations equally (causing tests diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 35ba6ba..dda896f 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -47,6 +47,27 @@ class TestJWT: assert decoded_payload == example_payload + def test_decodes_complete_valid_jwt(self, jwt): + example_payload = {"hello": "world"} + example_secret = "secret" + example_jwt = ( + b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + b".eyJoZWxsbyI6ICJ3b3JsZCJ9" + b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" + ) + decoded = jwt.decode_complete( + example_jwt, example_secret, algorithms=["HS256"] + ) + + assert decoded == { + "header": {"alg": "HS256", "typ": "JWT"}, + "payload": example_payload, + "signature": ( + b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2' + b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?" + ), + } + def test_load_verify_valid_jwt(self, jwt): example_payload = {"hello": "world"} example_secret = "secret" @@ -313,13 +334,12 @@ class TestJWT: secret = "secret" jwt_message = jwt.encode(payload, secret) - decoded_payload, signing, header, signature = jwt._load(jwt_message) - # With 3 seconds leeway, should be ok for leeway in (3, timedelta(seconds=3)): - jwt.decode( + decoded = jwt.decode( jwt_message, secret, leeway=leeway, algorithms=["HS256"] ) + assert decoded == payload # With 1 seconds, should fail for leeway in (1, timedelta(seconds=1)): |