diff options
-rw-r--r-- | docs/conf.py | 4 | ||||
-rw-r--r-- | jwt/algorithms.py | 16 | ||||
-rw-r--r-- | jwt/api_jwk.py | 8 | ||||
-rw-r--r-- | jwt/api_jws.py | 12 | ||||
-rw-r--r-- | jwt/api_jwt.py | 16 | ||||
-rw-r--r-- | jwt/jwks_client.py | 4 | ||||
-rw-r--r-- | pyproject.toml | 5 | ||||
-rw-r--r-- | setup.cfg | 1 | ||||
-rw-r--r-- | tests/test_algorithms.py | 8 | ||||
-rw-r--r-- | tests/test_api_jws.py | 72 | ||||
-rw-r--r-- | tests/test_api_jwt.py | 60 | ||||
-rw-r--r-- | tests/test_jwks_client.py | 4 | ||||
-rw-r--r-- | tests/utils.py | 4 |
13 files changed, 51 insertions, 163 deletions
diff --git a/docs/conf.py b/docs/conf.py index 297e903..c88580d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,9 +20,7 @@ def find_version(*file_paths): string inside. """ version_file = read(*file_paths) - version_match = re.search( - r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M - ) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 1b0f416..35e43bf 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -444,23 +444,17 @@ if has_crypto: if len(x) == len(y) == 32: curve_obj = ec.SECP256R1() else: - raise InvalidKeyError( - "Coords should be 32 bytes for curve P-256" - ) + raise InvalidKeyError("Coords should be 32 bytes for curve P-256") elif curve == "P-384": if len(x) == len(y) == 48: curve_obj = ec.SECP384R1() else: - raise InvalidKeyError( - "Coords should be 48 bytes for curve P-384" - ) + raise InvalidKeyError("Coords should be 48 bytes for curve P-384") elif curve == "P-521": if len(x) == len(y) == 66: curve_obj = ec.SECP521R1() else: - raise InvalidKeyError( - "Coords should be 66 bytes for curve P-521" - ) + raise InvalidKeyError("Coords should be 66 bytes for curve P-521") else: raise InvalidKeyError(f"Invalid curve: {curve}") @@ -568,8 +562,6 @@ if has_crypto: if isinstance(key, Ed25519PrivateKey): key = key.public_key() key.verify(sig, msg) - return ( - True # If no exception was raised, the signature is valid. - ) + return True # If no exception was raised, the signature is valid. except cryptography.exceptions.InvalidSignature: return False diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index 0966e40..3c10e07 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -13,16 +13,12 @@ class PyJWK: algorithm = self._jwk_data.get("alg", None) if not algorithm: - raise PyJWKError( - "Unable to find a algorithm for key: %s" % self._jwk_data - ) + raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data) self.Algorithm = self._algorithms.get(algorithm) if not self.Algorithm: - raise PyJWKError( - "Unable to find a algorithm for key: %s" % self._jwk_data - ) + raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data) self.key = self.Algorithm.from_jwk(self._jwk_data) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 152c237..6d13699 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -24,9 +24,7 @@ class PyJWS: def __init__(self, algorithms=None, options=None): self._algorithms = get_default_algorithms() self._valid_algs = ( - set(algorithms) - if algorithms is not None - else set(self._algorithms) + set(algorithms) if algorithms is not None else set(self._algorithms) ) # Remove algorithms that aren't on the whitelist @@ -148,9 +146,7 @@ class PyJWS: payload, signing_input, header, signature = self._load(jwt) if verify_signature: - self._verify_signature( - signing_input, header, signature, key, algorithms - ) + self._verify_signature(signing_input, header, signature, key, algorithms) return { "payload": payload, @@ -230,9 +226,7 @@ class PyJWS: alg = header.get("alg") if algorithms is not None and alg not in algorithms: - raise InvalidAlgorithmError( - "The specified alg value is not allowed" - ) + raise InvalidAlgorithmError("The specified alg value is not allowed") try: alg_obj = self._algorithms[alg] diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index da273c5..80b33bf 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -54,17 +54,13 @@ class PyJWT: for time_claim in ["exp", "iat", "nbf"]: # Convert datetime to a intDate value in known time-format claims if isinstance(payload.get(time_claim), datetime): - payload[time_claim] = timegm( - payload[time_claim].utctimetuple() - ) + payload[time_claim] = timegm(payload[time_claim].utctimetuple()) json_payload = json.dumps( payload, separators=(",", ":"), cls=json_encoder ).encode("utf-8") - return api_jws.encode( - json_payload, key, algorithm, headers, json_encoder - ) + return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) def decode_complete( self, @@ -154,9 +150,7 @@ class PyJWT: try: int(payload["iat"]) except ValueError: - raise InvalidIssuedAtError( - "Issued At claim (iat) must be an integer." - ) + raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") def _validate_nbf(self, payload, now, leeway): try: @@ -171,9 +165,7 @@ class PyJWT: try: exp = int(payload["exp"]) except ValueError: - raise DecodeError( - "Expiration Time claim (exp) must be an" " integer." - ) + raise DecodeError("Expiration Time claim (exp) must be an" " integer.") if exp < (now - leeway): raise ExpiredSignatureError("Signature has expired") diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index ecd10cb..2e2ff73 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -27,9 +27,7 @@ class PyJWKClient: signing_keys.append(jwk_set_key) if len(signing_keys) == 0: - raise PyJWKClientError( - "The JWKS endpoint did not contain any signing keys" - ) + raise PyJWKClientError("The JWKS endpoint did not contain any signing keys") return signing_keys diff --git a/pyproject.toml b/pyproject.toml index 95cc31b..85e9eca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,7 @@ source = ["jwt", ".tox/*/site-packages"] show_missing = true -[tool.black] -line-length = 79 - - [tool.isort] profile = "black" atomic = true combine_as_imports = true -line_length = 79 @@ -64,7 +64,6 @@ exclude = tests.* [flake8] -max-line-length = 79 extend-ignore = E203, E501 [mypy] diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 3009833..3851891 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -238,9 +238,7 @@ class TestAlgorithms: # EC coordinates not equally long with pytest.raises(InvalidKeyError): - algo.from_jwk( - '{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}' - ) + algo.from_jwk('{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}') # EC coordinates length invalid for curve in ("P-256", "P-384", "P-521"): @@ -575,9 +573,7 @@ class TestAlgorithmsRFC7520: b"gd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4" ) - signature = base64url_decode( - b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0" - ) + signature = base64url_decode(b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0") algo = HMACAlgorithm(HMACAlgorithm.SHA256) key = algo.prepare_key(load_hmac_key()) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index bd56aac..01efb28 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -114,11 +114,7 @@ class TestJWS: def test_decode_missing_segments_throws_exception(self, jws): secret = "secret" - example_jws = ( - "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" - ".eyJoZWxsbyI6ICJ3b3JsZCJ9" - "" - ) # Missing segment + example_jws = "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9" # Missing segment with pytest.raises(DecodeError) as context: jws.decode(example_jws, secret, algorithms=["HS256"]) @@ -160,9 +156,7 @@ class TestJWS: exception = context.value assert str(exception) == "Invalid header string: must be a json object" - def test_encode_algorithm_param_should_be_case_sensitive( - self, jws, payload - ): + def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload): jws.encode(payload, "secret", algorithm="HS256") @@ -207,9 +201,7 @@ class TestJWS: b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM" ) - decoded_payload = jws.decode( - example_jws, example_secret, algorithms=["HS256"] - ) + decoded_payload = jws.decode(example_jws, example_secret, algorithms=["HS256"]) assert decoded_payload == payload @@ -221,9 +213,7 @@ class TestJWS: b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM" ) - decoded = jws.decode_complete( - example_jws, example_secret, algorithms=["HS256"] - ) + decoded = jws.decode_complete(example_jws, example_secret, algorithms=["HS256"]) assert decoded == { "header": {"alg": "HS256", "typ": "JWT"}, @@ -248,9 +238,7 @@ class TestJWS: b"eyJoZWxsbyI6IndvcmxkIn0.TORyNQab_MoXM7DvNKaTwbrJr4UY" b"d2SsX8hhlnWelQFmPFSf_JzC2EbLnar92t-bXsDovzxp25ExazrVHkfPkQ" ) - decoded_payload = jws.decode( - example_jws, example_pubkey, algorithms=["ES256"] - ) + decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["ES256"]) json_payload = json.loads(decoded_payload) assert json_payload == example_payload @@ -276,9 +264,7 @@ class TestJWS: b"uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr" b"qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A" ) - decoded_payload = jws.decode( - example_jws, example_pubkey, algorithms=["RS384"] - ) + decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["RS384"]) json_payload = json.loads(decoded_payload) assert json_payload == example_payload @@ -299,9 +285,7 @@ class TestJWS: def test_allow_skip_verification(self, jws, payload): right_secret = "foo" jws_message = jws.encode(payload, right_secret) - decoded_payload = jws.decode( - jws_message, options={"verify_signature": False} - ) + decoded_payload = jws.decode(jws_message, options={"verify_signature": False}) assert decoded_payload == payload @@ -364,14 +348,8 @@ class TestJWS: assert "Signature verification" in str(exc.value) - def test_verify_signature_with_no_algo_header_throws_exception( - self, jws, payload - ): - example_jws = ( - b"e30" - b".eyJhIjo1fQ" - b".KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY" - ) + def test_verify_signature_with_no_algo_header_throws_exception(self, jws, payload): + example_jws = b"e30.eyJhIjo1fQ.KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY" with pytest.raises(InvalidAlgorithmError): jws.decode(example_jws, "secret", algorithms=["HS256"]) @@ -467,9 +445,7 @@ class TestJWS: with pytest.raises(DecodeError): jws.decode(jws_message, algorithms=["none"]) - def test_decode_with_algo_none_and_verify_false_should_pass( - self, jws, payload - ): + def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload): jws_message = jws.encode(payload, key=None, algorithm=None) jws.decode(jws_message, options={"verify_signature": False}) @@ -486,9 +462,7 @@ class TestJWS: assert "kid" in header assert header["kid"] == "toomanysecrets" - def test_get_unverified_header_fails_on_bad_header_types( - self, jws, payload - ): + def test_get_unverified_header_fails_on_bad_header_types(self, jws, payload): # Contains a bad kid value (int 123 instead of string) example_jws = ( "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6MTIzfQ" @@ -505,9 +479,7 @@ class TestJWS: def test_encode_decode_with_rsa_sha256(self, jws, payload): # PEM-formatted RSA key with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file: - priv_rsakey = load_pem_private_key( - rsa_priv_file.read(), password=None - ) + priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_rsakey, algorithm="RS256") with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file: @@ -528,9 +500,7 @@ class TestJWS: def test_encode_decode_with_rsa_sha384(self, jws, payload): # PEM-formatted RSA key with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file: - priv_rsakey = load_pem_private_key( - rsa_priv_file.read(), password=None - ) + priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_rsakey, algorithm="RS384") with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file: @@ -550,9 +520,7 @@ class TestJWS: def test_encode_decode_with_rsa_sha512(self, jws, payload): # PEM-formatted RSA key with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file: - priv_rsakey = load_pem_private_key( - rsa_priv_file.read(), password=None - ) + priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_rsakey, algorithm="RS512") with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file: @@ -592,9 +560,7 @@ class TestJWS: def test_encode_decode_with_ecdsa_sha256(self, jws, payload): # PEM-formatted EC key with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: - priv_eckey = load_pem_private_key( - ec_priv_file.read(), password=None - ) + priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_eckey, algorithm="ES256") with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file: @@ -615,9 +581,7 @@ class TestJWS: # PEM-formatted EC key with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: - priv_eckey = load_pem_private_key( - ec_priv_file.read(), password=None - ) + priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_eckey, algorithm="ES384") with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file: @@ -637,9 +601,7 @@ class TestJWS: def test_encode_decode_with_ecdsa_sha512(self, jws, payload): # PEM-formatted EC key with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: - priv_eckey = load_pem_private_key( - ec_priv_file.read(), password=None - ) + priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None) jws_message = jws.encode(payload, priv_eckey, algorithm="ES512") with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file: diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index e575668..f161763 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -40,9 +40,7 @@ class TestJWT: b".eyJoZWxsbyI6ICJ3b3JsZCJ9" b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" ) - decoded_payload = jwt.decode( - example_jwt, example_secret, algorithms=["HS256"] - ) + decoded_payload = jwt.decode(example_jwt, example_secret, algorithms=["HS256"]) assert decoded_payload == example_payload @@ -54,9 +52,7 @@ class TestJWT: b".eyJoZWxsbyI6ICJ3b3JsZCJ9" b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" ) - decoded = jwt.decode_complete( - example_jwt, example_secret, algorithms=["HS256"] - ) + decoded = jwt.decode_complete(example_jwt, example_secret, algorithms=["HS256"]) assert decoded == { "header": {"alg": "HS256", "typ": "JWT"}, @@ -107,9 +103,7 @@ class TestJWT: jwt.decode(example_jwt, secret, algorithms=["HS256"]) exception = context.value - assert ( - str(exception) == "Invalid payload string: must be a json object" - ) + assert str(exception) == "Invalid payload string: must be a json object" def test_decode_with_invalid_audience_param_throws_exception(self, jwt): secret = "secret" @@ -221,15 +215,9 @@ class TestJWT: jwt_message, secret, leeway=1, algorithms=["HS256"] ) - assert decoded_payload["exp"] == timegm( - current_datetime.utctimetuple() - ) - assert decoded_payload["iat"] == timegm( - current_datetime.utctimetuple() - ) - assert decoded_payload["nbf"] == timegm( - current_datetime.utctimetuple() - ) + assert decoded_payload["exp"] == timegm(current_datetime.utctimetuple()) + assert decoded_payload["iat"] == timegm(current_datetime.utctimetuple()) + assert decoded_payload["nbf"] == timegm(current_datetime.utctimetuple()) # payload is not mutated. assert payload == { "exp": current_datetime, @@ -252,9 +240,7 @@ class TestJWT: b"d2SsX8hhlnWelQFmPFSf_JzC2EbLnar92t-bXsDovzxp25ExazrVHkfPkQ" ) - decoded_payload = jwt.decode( - example_jwt, example_pubkey, algorithms=["ES256"] - ) + decoded_payload = jwt.decode(example_jwt, example_pubkey, algorithms=["ES256"]) assert decoded_payload == example_payload # 'Control' RSA JWT created by another library. @@ -278,9 +264,7 @@ class TestJWT: b"uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr" b"qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A" ) - decoded_payload = jwt.decode( - example_jwt, example_pubkey, algorithms=["RS384"] - ) + decoded_payload = jwt.decode(example_jwt, example_pubkey, algorithms=["RS384"]) assert decoded_payload == example_payload @@ -339,9 +323,7 @@ class TestJWT: # With 1 seconds, should fail for leeway in (1, timedelta(seconds=1)): with pytest.raises(ExpiredSignatureError): - jwt.decode( - jwt_message, secret, leeway=leeway, algorithms=["HS256"] - ) + jwt.decode(jwt_message, secret, leeway=leeway, algorithms=["HS256"]) def test_decode_with_notbefore_with_leeway(self, jwt, payload): payload["nbf"] = utc_timestamp() + 10 @@ -397,9 +379,7 @@ class TestJWT: token = jwt.encode(payload, "secret") with pytest.raises(InvalidAudienceError): - jwt.decode( - token, "secret", audience="urn-me", algorithms=["HS256"] - ) + jwt.decode(token, "secret", audience="urn-me", algorithms=["HS256"]) def test_raise_exception_invalid_audience_in_array(self, jwt): payload = { @@ -410,9 +390,7 @@ class TestJWT: token = jwt.encode(payload, "secret") with pytest.raises(InvalidAudienceError): - jwt.decode( - token, "secret", audience="urn:me", algorithms=["HS256"] - ) + jwt.decode(token, "secret", audience="urn:me", algorithms=["HS256"]) def test_raise_exception_token_without_issuer(self, jwt): issuer = "urn:wrong" @@ -431,9 +409,7 @@ class TestJWT: token = jwt.encode(payload, "secret") with pytest.raises(MissingRequiredClaimError) as exc: - jwt.decode( - token, "secret", audience="urn:me", algorithms=["HS256"] - ) + jwt.decode(token, "secret", audience="urn:me", algorithms=["HS256"]) assert exc.value.claim == "aud" @@ -476,9 +452,7 @@ class TestJWT: algorithms=["HS256"], ) - def test_decode_should_raise_error_if_exp_required_but_not_present( - self, jwt - ): + def test_decode_should_raise_error_if_exp_required_but_not_present(self, jwt): payload = { "some": "payload", # exp not present @@ -495,9 +469,7 @@ class TestJWT: assert exc.value.claim == "exp" - def test_decode_should_raise_error_if_iat_required_but_not_present( - self, jwt - ): + def test_decode_should_raise_error_if_iat_required_but_not_present(self, jwt): payload = { "some": "payload", # iat not present @@ -514,9 +486,7 @@ class TestJWT: assert exc.value.claim == "iat" - def test_decode_should_raise_error_if_nbf_required_but_not_present( - self, jwt - ): + def test_decode_should_raise_error_if_nbf_required_but_not_present(self, jwt): payload = { "some": "payload", # nbf not present diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index d607863..7c3dde4 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -73,9 +73,7 @@ class TestPyJWKClient: with pytest.raises(PyJWKClientError) as exc: jwks_client.get_signing_keys() - assert "The JWKS endpoint did not contain any signing keys" in str( - exc.value - ) + assert "The JWKS endpoint did not contain any signing keys" in str(exc.value) def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" diff --git a/tests/utils.py b/tests/utils.py index 6847aba..faa3d5a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,9 +12,7 @@ def utc_timestamp(): def key_path(key_name): - return os.path.join( - os.path.dirname(os.path.realpath(__file__)), "keys", key_name - ) + return os.path.join(os.path.dirname(os.path.realpath(__file__)), "keys", key_name) def no_crypto_required(class_or_func): |