diff options
author | Jon Dufresne <jon.dufresne@gmail.com> | 2020-12-19 15:32:12 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-19 18:32:12 -0500 |
commit | 94d102b2e38d09f3c5ec459a8365de8a86c51fe5 (patch) | |
tree | 4b31edc95cf75e0dc79db53866b06af242e13cd5 | |
parent | 2e1e69d4ddfddaba35b6ee99ead1b430654ed661 (diff) | |
download | pyjwt-94d102b2e38d09f3c5ec459a8365de8a86c51fe5.tar.gz |
Split PyJWT/PyJWS classes to tighten type interfaces (#559)
The class PyJWT was previously a subclass of PyJWS. However, this
combination does not follow the Liskov substitution principle. That is,
using PyJWT in place of a PyJWS would not produce correct results or
follow type contracts.
While these classes look to share a common interface it doesn't go
beyond the method names "encode" and "decode" and so is merely
superficial.
The classes have been split into two. PyJWT now uses composition instead
of inheritance to achieve the desired behavior. Splitting the classes in
this way allowed for precising the type interfaces.
The complete parameter to .decode() has been removed. This argument was
used to alter the return type of .decode(). Now, there are two different
methods with more explicit return types and values. The new method name
is .decode_complete(). This fills the previous role filled by
.decode(..., complete=True).
Closes #554, #396, #394
Co-authored-by: Sam Bull <git@sambull.org>
Co-authored-by: Sam Bull <git@sambull.org>
-rw-r--r-- | jwt/__init__.py | 8 | ||||
-rw-r--r-- | jwt/api_jws.py | 32 | ||||
-rw-r--r-- | jwt/api_jwt.py | 45 | ||||
-rw-r--r-- | jwt/jwks_client.py | 6 | ||||
-rw-r--r-- | jwt/utils.py | 2 | ||||
-rw-r--r-- | tests/test_api_jws.py | 21 | ||||
-rw-r--r-- | tests/test_api_jwt.py | 26 |
7 files changed, 94 insertions, 46 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 9d95053..55a5ae3 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,12 +1,10 @@ -from .api_jws import PyJWS -from .api_jwt import ( - PyJWT, - decode, - encode, +from .api_jws import ( + PyJWS, get_unverified_header, register_algorithm, unregister_algorithm, ) +from .api_jwt import PyJWT, decode, encode from .exceptions import ( DecodeError, ExpiredSignatureError, diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 17c49db..6ccce80 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,7 +1,7 @@ import binascii import json from collections.abc import Mapping -from typing import Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from .algorithms import ( Algorithm, @@ -77,7 +77,7 @@ class PyJWS: def encode( self, - payload: Union[Dict, bytes], + payload: bytes, key: str, algorithm: str = "HS256", headers: Optional[Dict] = None, @@ -127,15 +127,14 @@ class PyJWS: return encoded_string.decode("utf-8") - def decode( + def decode_complete( self, jwt: str, key: str = "", algorithms: List[str] = None, options: Dict = None, - complete: bool = False, **kwargs, - ): + ) -> Dict[str, Any]: if options is None: options = {} merged_options = {**self.options, **options} @@ -153,14 +152,22 @@ class PyJWS: payload, signing_input, header, signature, key, algorithms ) - if complete: - return { - "payload": payload, - "header": header, - "signature": signature, - } + return { + "payload": payload, + "header": header, + "signature": signature, + } - return payload + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + **kwargs, + ) -> str: + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + return decoded["payload"] def get_unverified_header(self, jwt): """Returns back the JWT header parameters as a dict() @@ -249,6 +256,7 @@ class PyJWS: _jws_global_obj = PyJWS() encode = _jws_global_obj.encode +decode_complete = _jws_global_obj.decode_complete decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 06c89f4..da273c5 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Type, Union -from .api_jws import PyJWS +from . import api_jws from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -16,8 +16,11 @@ from .exceptions import ( ) -class PyJWT(PyJWS): - header_type = "JWT" +class PyJWT: + def __init__(self, options=None): + if options is None: + options = {} + self.options = {**self._get_default_options(), **options} @staticmethod def _get_default_options() -> Dict[str, Union[bool, List[str]]]: @@ -33,7 +36,7 @@ class PyJWT(PyJWS): def encode( self, - payload: Union[Dict, bytes], + payload: Dict[str, Any], key: str, algorithm: str = "HS256", headers: Optional[Dict] = None, @@ -59,20 +62,18 @@ class PyJWT(PyJWS): payload, separators=(",", ":"), cls=json_encoder ).encode("utf-8") - return super().encode( + return api_jws.encode( json_payload, key, algorithm, headers, json_encoder ) - def decode( + def decode_complete( self, jwt: str, key: str = "", algorithms: List[str] = None, options: Dict = None, - complete: bool = False, **kwargs, ) -> Dict[str, Any]: - if options is None: options = {"verify_signature": True} else: @@ -83,20 +84,16 @@ class PyJWT(PyJWS): 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) - decoded = super().decode( + decoded = api_jws.decode_complete( jwt, key=key, algorithms=algorithms, options=options, - complete=complete, **kwargs, ) try: - if complete: - payload = json.loads(decoded["payload"]) - else: - payload = json.loads(decoded) + payload = json.loads(decoded["payload"]) except ValueError as e: raise DecodeError("Invalid payload string: %s" % e) if not isinstance(payload, dict): @@ -106,11 +103,19 @@ class PyJWT(PyJWS): merged_options = {**self.options, **options} self._validate_claims(payload, merged_options, **kwargs) - if complete: - decoded["payload"] = payload - return decoded + decoded["payload"] = payload + return decoded - return payload + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + **kwargs, + ) -> Dict[str, Any]: + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + return decoded["payload"] def _validate_claims( self, payload, options, audience=None, issuer=None, leeway=0, **kwargs @@ -215,7 +220,5 @@ class PyJWT(PyJWS): _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode +decode_complete = _jwt_global_obj.decode_complete decode = _jwt_global_obj.decode -register_algorithm = _jwt_global_obj.register_algorithm -unregister_algorithm = _jwt_global_obj.unregister_algorithm -get_unverified_header = _jwt_global_obj.get_unverified_header diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 707724e..ecd10cb 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -2,7 +2,7 @@ import json import urllib.request from .api_jwk import PyJWKSet -from .api_jwt import decode as decode_token +from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientError @@ -50,8 +50,6 @@ class PyJWKClient: return signing_key def get_signing_key_from_jwt(self, token): - unverified = decode_token( - token, complete=True, options={"verify_signature": False} - ) + unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) diff --git a/jwt/utils.py b/jwt/utils.py index 495c8c7..f9579d3 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -32,7 +32,7 @@ def base64url_decode(input): return base64.urlsafe_b64decode(input) -def base64url_encode(input): +def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index a0456dc..2e8dd90 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -215,6 +215,27 @@ class TestJWS: assert decoded_payload == payload + def test_decodes_complete_valid_jws(self, jws, payload): + example_secret = "secret" + example_jws = ( + b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + b"aGVsbG8gd29ybGQ." + b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM" + ) + + decoded = jws.decode_complete( + example_jws, example_secret, algorithms=["HS256"] + ) + + assert decoded == { + "header": {"alg": "HS256", "typ": "JWT"}, + "payload": payload, + "signature": ( + b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4" + b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3" + ), + } + # 'Control' Elliptic Curve jws created by another library. # Used to test for regressions that could affect both # encoding / decoding operations equally (causing tests diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 35ba6ba..dda896f 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -47,6 +47,27 @@ class TestJWT: assert decoded_payload == example_payload + def test_decodes_complete_valid_jwt(self, jwt): + example_payload = {"hello": "world"} + example_secret = "secret" + example_jwt = ( + b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + b".eyJoZWxsbyI6ICJ3b3JsZCJ9" + b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" + ) + decoded = jwt.decode_complete( + example_jwt, example_secret, algorithms=["HS256"] + ) + + assert decoded == { + "header": {"alg": "HS256", "typ": "JWT"}, + "payload": example_payload, + "signature": ( + b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2' + b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?" + ), + } + def test_load_verify_valid_jwt(self, jwt): example_payload = {"hello": "world"} example_secret = "secret" @@ -313,13 +334,12 @@ class TestJWT: secret = "secret" jwt_message = jwt.encode(payload, secret) - decoded_payload, signing, header, signature = jwt._load(jwt_message) - # With 3 seconds leeway, should be ok for leeway in (3, timedelta(seconds=3)): - jwt.decode( + decoded = jwt.decode( jwt_message, secret, leeway=leeway, algorithms=["HS256"] ) + assert decoded == payload # With 1 seconds, should fail for leeway in (1, timedelta(seconds=1)): |