summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Dufresne <jon.dufresne@gmail.com>2020-12-19 15:32:12 -0800
committerGitHub <noreply@github.com>2020-12-19 18:32:12 -0500
commit94d102b2e38d09f3c5ec459a8365de8a86c51fe5 (patch)
tree4b31edc95cf75e0dc79db53866b06af242e13cd5
parent2e1e69d4ddfddaba35b6ee99ead1b430654ed661 (diff)
downloadpyjwt-94d102b2e38d09f3c5ec459a8365de8a86c51fe5.tar.gz
Split PyJWT/PyJWS classes to tighten type interfaces (#559)
The class PyJWT was previously a subclass of PyJWS. However, this combination does not follow the Liskov substitution principle. That is, using PyJWT in place of a PyJWS would not produce correct results or follow type contracts. While these classes look to share a common interface it doesn't go beyond the method names "encode" and "decode" and so is merely superficial. The classes have been split into two. PyJWT now uses composition instead of inheritance to achieve the desired behavior. Splitting the classes in this way allowed for precising the type interfaces. The complete parameter to .decode() has been removed. This argument was used to alter the return type of .decode(). Now, there are two different methods with more explicit return types and values. The new method name is .decode_complete(). This fills the previous role filled by .decode(..., complete=True). Closes #554, #396, #394 Co-authored-by: Sam Bull <git@sambull.org> Co-authored-by: Sam Bull <git@sambull.org>
-rw-r--r--jwt/__init__.py8
-rw-r--r--jwt/api_jws.py32
-rw-r--r--jwt/api_jwt.py45
-rw-r--r--jwt/jwks_client.py6
-rw-r--r--jwt/utils.py2
-rw-r--r--tests/test_api_jws.py21
-rw-r--r--tests/test_api_jwt.py26
7 files changed, 94 insertions, 46 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py
index 9d95053..55a5ae3 100644
--- a/jwt/__init__.py
+++ b/jwt/__init__.py
@@ -1,12 +1,10 @@
-from .api_jws import PyJWS
-from .api_jwt import (
- PyJWT,
- decode,
- encode,
+from .api_jws import (
+ PyJWS,
get_unverified_header,
register_algorithm,
unregister_algorithm,
)
+from .api_jwt import PyJWT, decode, encode
from .exceptions import (
DecodeError,
ExpiredSignatureError,
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 17c49db..6ccce80 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,7 +1,7 @@
import binascii
import json
from collections.abc import Mapping
-from typing import Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type
from .algorithms import (
Algorithm,
@@ -77,7 +77,7 @@ class PyJWS:
def encode(
self,
- payload: Union[Dict, bytes],
+ payload: bytes,
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
@@ -127,15 +127,14 @@ class PyJWS:
return encoded_string.decode("utf-8")
- def decode(
+ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
- complete: bool = False,
**kwargs,
- ):
+ ) -> Dict[str, Any]:
if options is None:
options = {}
merged_options = {**self.options, **options}
@@ -153,14 +152,22 @@ class PyJWS:
payload, signing_input, header, signature, key, algorithms
)
- if complete:
- return {
- "payload": payload,
- "header": header,
- "signature": signature,
- }
+ return {
+ "payload": payload,
+ "header": header,
+ "signature": signature,
+ }
- return payload
+ def decode(
+ self,
+ jwt: str,
+ key: str = "",
+ algorithms: List[str] = None,
+ options: Dict = None,
+ **kwargs,
+ ) -> str:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded["payload"]
def get_unverified_header(self, jwt):
"""Returns back the JWT header parameters as a dict()
@@ -249,6 +256,7 @@ class PyJWS:
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
+decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 06c89f4..da273c5 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type, Union
-from .api_jws import PyJWS
+from . import api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
@@ -16,8 +16,11 @@ from .exceptions import (
)
-class PyJWT(PyJWS):
- header_type = "JWT"
+class PyJWT:
+ def __init__(self, options=None):
+ if options is None:
+ options = {}
+ self.options = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -33,7 +36,7 @@ class PyJWT(PyJWS):
def encode(
self,
- payload: Union[Dict, bytes],
+ payload: Dict[str, Any],
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
@@ -59,20 +62,18 @@ class PyJWT(PyJWS):
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")
- return super().encode(
+ return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)
- def decode(
+ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
- complete: bool = False,
**kwargs,
) -> Dict[str, Any]:
-
if options is None:
options = {"verify_signature": True}
else:
@@ -83,20 +84,16 @@ class PyJWT(PyJWS):
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
- decoded = super().decode(
+ decoded = api_jws.decode_complete(
jwt,
key=key,
algorithms=algorithms,
options=options,
- complete=complete,
**kwargs,
)
try:
- if complete:
- payload = json.loads(decoded["payload"])
- else:
- payload = json.loads(decoded)
+ payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError("Invalid payload string: %s" % e)
if not isinstance(payload, dict):
@@ -106,11 +103,19 @@ class PyJWT(PyJWS):
merged_options = {**self.options, **options}
self._validate_claims(payload, merged_options, **kwargs)
- if complete:
- decoded["payload"] = payload
- return decoded
+ decoded["payload"] = payload
+ return decoded
- return payload
+ def decode(
+ self,
+ jwt: str,
+ key: str = "",
+ algorithms: List[str] = None,
+ options: Dict = None,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded["payload"]
def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
@@ -215,7 +220,5 @@ class PyJWT(PyJWS):
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
+decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode
-register_algorithm = _jwt_global_obj.register_algorithm
-unregister_algorithm = _jwt_global_obj.unregister_algorithm
-get_unverified_header = _jwt_global_obj.get_unverified_header
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index 707724e..ecd10cb 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -2,7 +2,7 @@ import json
import urllib.request
from .api_jwk import PyJWKSet
-from .api_jwt import decode as decode_token
+from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
@@ -50,8 +50,6 @@ class PyJWKClient:
return signing_key
def get_signing_key_from_jwt(self, token):
- unverified = decode_token(
- token, complete=True, options={"verify_signature": False}
- )
+ unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))
diff --git a/jwt/utils.py b/jwt/utils.py
index 495c8c7..f9579d3 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -32,7 +32,7 @@ def base64url_decode(input):
return base64.urlsafe_b64decode(input)
-def base64url_encode(input):
+def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")
diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py
index a0456dc..2e8dd90 100644
--- a/tests/test_api_jws.py
+++ b/tests/test_api_jws.py
@@ -215,6 +215,27 @@ class TestJWS:
assert decoded_payload == payload
+ def test_decodes_complete_valid_jws(self, jws, payload):
+ example_secret = "secret"
+ example_jws = (
+ b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9."
+ b"aGVsbG8gd29ybGQ."
+ b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
+ )
+
+ decoded = jws.decode_complete(
+ example_jws, example_secret, algorithms=["HS256"]
+ )
+
+ assert decoded == {
+ "header": {"alg": "HS256", "typ": "JWT"},
+ "payload": payload,
+ "signature": (
+ b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4"
+ b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3"
+ ),
+ }
+
# 'Control' Elliptic Curve jws created by another library.
# Used to test for regressions that could affect both
# encoding / decoding operations equally (causing tests
diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py
index 35ba6ba..dda896f 100644
--- a/tests/test_api_jwt.py
+++ b/tests/test_api_jwt.py
@@ -47,6 +47,27 @@ class TestJWT:
assert decoded_payload == example_payload
+ def test_decodes_complete_valid_jwt(self, jwt):
+ example_payload = {"hello": "world"}
+ example_secret = "secret"
+ example_jwt = (
+ b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
+ b".eyJoZWxsbyI6ICJ3b3JsZCJ9"
+ b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8"
+ )
+ decoded = jwt.decode_complete(
+ example_jwt, example_secret, algorithms=["HS256"]
+ )
+
+ assert decoded == {
+ "header": {"alg": "HS256", "typ": "JWT"},
+ "payload": example_payload,
+ "signature": (
+ b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2'
+ b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?"
+ ),
+ }
+
def test_load_verify_valid_jwt(self, jwt):
example_payload = {"hello": "world"}
example_secret = "secret"
@@ -313,13 +334,12 @@ class TestJWT:
secret = "secret"
jwt_message = jwt.encode(payload, secret)
- decoded_payload, signing, header, signature = jwt._load(jwt_message)
-
# With 3 seconds leeway, should be ok
for leeway in (3, timedelta(seconds=3)):
- jwt.decode(
+ decoded = jwt.decode(
jwt_message, secret, leeway=leeway, algorithms=["HS256"]
)
+ assert decoded == payload
# With 1 seconds, should fail
for leeway in (1, timedelta(seconds=1)):