diff options
author | José Padilla <jpadilla@webapplicate.com> | 2020-08-24 06:12:23 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-24 16:12:23 +0600 |
commit | dee2d31640940bfc77c3bc5839587a082eae58a4 (patch) | |
tree | 5f37fa1dbe2ee5c9485a8f904cb146fa670199d3 | |
parent | 5d33b041f2aeccc653436958c08187f1bd538d94 (diff) | |
download | pyjwt-dee2d31640940bfc77c3bc5839587a082eae58a4.tar.gz |
Introduce better experience for JWKs (#511)
* Introduce better experience for JWKs
* Remove explicit inheritance
* Add tests for PyJWK
* Fix failing test
* Get rid of lambda
-rw-r--r-- | jwt/__init__.py | 66 | ||||
-rw-r--r-- | jwt/algorithms.py | 7 | ||||
-rw-r--r-- | jwt/api_jwk.py | 72 | ||||
-rw-r--r-- | jwt/api_jws.py | 8 | ||||
-rw-r--r-- | jwt/api_jwt.py | 17 | ||||
-rw-r--r-- | jwt/exceptions.py | 12 | ||||
-rw-r--r-- | jwt/jwks_client.py | 64 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | setup.cfg | 4 | ||||
-rw-r--r-- | tests/test_api_jwk.py | 116 | ||||
-rw-r--r-- | tests/test_jwks_client.py | 113 |
11 files changed, 459 insertions, 22 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 40bfdb0..5695848 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,20 +1,3 @@ -# flake8: noqa - -""" -JSON Web Token implementation - -Minimum implementation based on this spec: -https://self-issued.info/docs/draft-jones-json-web-token-01.html -""" - - -__title__ = "pyjwt" -__version__ = "1.7.1" -__author__ = "José Padilla" -__license__ = "MIT" -__copyright__ = "Copyright 2015-2018 José Padilla" - - from .api_jws import PyJWS from .api_jwt import ( PyJWT, @@ -35,9 +18,56 @@ from .exceptions import ( InvalidIssuedAtError, InvalidIssuer, InvalidIssuerError, - InvalidKeyError, InvalidSignatureError, InvalidTokenError, MissingRequiredClaimError, + PyJWKClientError, + PyJWKError, + PyJWKSetError, PyJWTError, ) +from .jwks_client import PyJWKClient + +__version__ = "2.0.0.dev" + +__title__ = "PyJWT" +__description__ = "JSON Web Token implementation in Python" +__url__ = "https://pyjwt.readthedocs.io" +__uri__ = __url__ +__doc__ = __description__ + " <" + __uri__ + ">" + +__author__ = "José Padilla" +__email__ = "hello@jpadilla.com" + +__license__ = "MIT" +__copyright__ = "Copyright 2015-2020 José Padilla" + + +__all__ = [ + "PyJWS", + "PyJWT", + "PyJWKClient", + "decode", + "encode", + "get_unverified_header", + "register_algorithm", + "unregister_algorithm", + # Exceptions + "DecodeError", + "ExpiredSignature", + "ExpiredSignatureError", + "ImmatureSignatureError", + "InvalidAlgorithmError", + "InvalidAudience", + "InvalidAudienceError", + "InvalidIssuedAtError", + "InvalidIssuer", + "InvalidIssuerError", + "InvalidSignatureError", + "InvalidTokenError", + "MissingRequiredClaimError", + "PyJWKClientError", + "PyJWKError", + "PyJWKSetError", + "PyJWTError", +] diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 3a7040c..be41ff1 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -299,7 +299,12 @@ if has_crypto: # noqa: C901 @staticmethod def from_jwk(jwk): try: - obj = json.loads(jwk) + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError except ValueError: raise InvalidKeyError("Key is not valid JSON") diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py new file mode 100644 index 0000000..0966e40 --- /dev/null +++ b/jwt/api_jwk.py @@ -0,0 +1,72 @@ +import json + +from .algorithms import get_default_algorithms +from .exceptions import PyJWKError, PyJWKSetError + + +class PyJWK: + def __init__(self, jwk_data, algorithm=None): + self._algorithms = get_default_algorithms() + self._jwk_data = jwk_data + + if not algorithm and isinstance(self._jwk_data, dict): + algorithm = self._jwk_data.get("alg", None) + + if not algorithm: + raise PyJWKError( + "Unable to find a algorithm for key: %s" % self._jwk_data + ) + + self.Algorithm = self._algorithms.get(algorithm) + + if not self.Algorithm: + raise PyJWKError( + "Unable to find a algorithm for key: %s" % self._jwk_data + ) + + self.key = self.Algorithm.from_jwk(self._jwk_data) + + @staticmethod + def from_dict(obj, algorithm=None): + return PyJWK(obj, algorithm) + + @staticmethod + def from_json(data, algorithm=None): + obj = json.loads(data) + return PyJWK.from_dict(obj, algorithm) + + @property + def key_type(self): + return self._jwk_data.get("kty", None) + + @property + def key_id(self): + return self._jwk_data.get("kid", None) + + @property + def public_key_use(self): + return self._jwk_data.get("use", None) + + +class PyJWKSet: + def __init__(self, keys): + self.keys = [] + + if not keys or not isinstance(keys, list): + raise PyJWKSetError("Invalid JWK Set value") + + if len(keys) == 0: + raise PyJWKSetError("The JWK Set did not contain any keys") + + for key in keys: + self.keys.append(PyJWK(key)) + + @staticmethod + def from_dict(obj): + keys = obj.get("keys", []) + return PyJWKSet(keys) + + @staticmethod + def from_json(data): + obj = json.loads(data) + return PyJWKSet.from_dict(obj) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 84465c1..63b004f 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -137,6 +137,7 @@ class PyJWS: verify=True, # type: bool algorithms=None, # type: List[str] options=None, # type: Dict + complete=False, # type: bool **kwargs ): @@ -166,6 +167,13 @@ class PyJWS: payload, signing_input, header, signature, key, algorithms ) + if complete: + return { + "payload": payload, + "header": header, + "signature": signature, + } + return payload def get_unverified_header(self, jwt): diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index b42a4d6..348dd74 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -79,6 +79,7 @@ class PyJWT(PyJWS): verify=True, # type: bool algorithms=None, # type: List[str] options=None, # type: Dict + complete=False, # type: bool **kwargs ): # type: (...) -> Dict[str, Any] @@ -100,11 +101,19 @@ class PyJWT(PyJWS): options.setdefault("verify_signature", verify) decoded = super().decode( - jwt, key=key, algorithms=algorithms, options=options, **kwargs + jwt, + key=key, + algorithms=algorithms, + options=options, + complete=complete, + **kwargs ) try: - payload = json.loads(decoded.decode("utf-8")) + if complete: + payload = json.loads(decoded["payload"].decode("utf-8")) + else: + payload = json.loads(decoded.decode("utf-8")) except ValueError as e: raise DecodeError("Invalid payload string: %s" % e) if not isinstance(payload, dict): @@ -114,6 +123,10 @@ class PyJWT(PyJWS): merged_options = merge_dict(self.options, options) self._validate_claims(payload, merged_options, **kwargs) + if complete: + decoded["payload"] = payload + return decoded + return payload def _validate_claims( diff --git a/jwt/exceptions.py b/jwt/exceptions.py index cd2ca2a..abe088b 100644 --- a/jwt/exceptions.py +++ b/jwt/exceptions.py @@ -54,6 +54,18 @@ class MissingRequiredClaimError(InvalidTokenError): return 'Token is missing the "%s" claim' % self.claim +class PyJWKError(PyJWTError): + pass + + +class PyJWKSetError(PyJWTError): + pass + + +class PyJWKClientError(PyJWTError): + pass + + # Compatibility aliases (deprecated) ExpiredSignature = ExpiredSignatureError InvalidAudience = InvalidAudienceError diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py new file mode 100644 index 0000000..c8f50e4 --- /dev/null +++ b/jwt/jwks_client.py @@ -0,0 +1,64 @@ +from .api_jwk import PyJWKSet +from .api_jwt import decode as decode_token +from .exceptions import PyJWKClientError + +try: + import requests + + has_requests = True +except ImportError: + has_requests = False + + +class PyJWKClient: + def __init__(self, uri): + if not has_requests: + raise PyJWKClientError( + "Missing dependencies for `PyJWKClient`. Run `pip install pyjwt[jwks-client]` to install dependencies." + ) + + self.uri = uri + + def fetch_data(self): + r = requests.get(self.uri) + return r.json() + + def get_jwk_set(self): + data = self.fetch_data() + return PyJWKSet.from_dict(data) + + def get_signing_keys(self): + jwk_set = self.get_jwk_set() + signing_keys = [] + + for jwk_set_key in jwk_set.keys: + if jwk_set_key.public_key_use == "sig" and jwk_set_key.key_id: + signing_keys.append(jwk_set_key) + + if len(signing_keys) == 0: + raise PyJWKClientError( + "The JWKS endpoint did not contain any signing keys" + ) + + return signing_keys + + def get_signing_key(self, kid): + signing_keys = self.get_signing_keys() + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + if not signing_key: + raise PyJWKClientError( + 'Unable to find a signing key that matches: "{}"'.format(kid) + ) + + return signing_key + + def get_signing_key_from_jwt(self, token): + unverified = decode_token(token, complete=True, verify=False) + header = unverified.get("header") + return self.get_signing_key(header.get("kid")) diff --git a/pyproject.toml b/pyproject.toml index 49b360f..0552d2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ use_parentheses=true combine_as_imports=true known_first_party="jwt" -known_third_party=["Crypto", "cryptography", "ecdsa", "pytest", "setuptools", "sphinx_rtd_theme"] +known_third_party=["Crypto", "cryptography", "ecdsa", "pytest", "requests_mock", "setuptools", "sphinx_rtd_theme"] @@ -42,6 +42,7 @@ crypto = tests = pytest>=6.0.0,<7.0.0 coverage[toml]==5.0.4 + requests-mock>=1.7.0,<2.0.0 dev = sphinx sphinx-rtd-theme @@ -49,8 +50,11 @@ dev = cryptography>=1.4 pytest>=6.0.0,<7.0.0 coverage[toml]==5.0.4 + requests mypy pre-commit +jwks-client = + requests [options.packages.find] exclude = diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py new file mode 100644 index 0000000..956133e --- /dev/null +++ b/tests/test_api_jwk.py @@ -0,0 +1,116 @@ +import json + +import pytest + +from jwt.api_jwk import PyJWK, PyJWKSet + +from .utils import key_path + +try: + from jwt.algorithms import RSAAlgorithm + + has_crypto = True +except ImportError: + has_crypto = False + + +class TestPyJWK: + @pytest.mark.skipif( + not has_crypto, + reason="Scenario requires cryptography to not be installed", + ) + def test_should_load_key_from_jwk_data_dict(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path("jwk_rsa_pub.json"), "r") as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + key_data_str = algo.to_jwk(pub_key) + key_data = json.loads(key_data_str) + + # TODO Should `to_jwk` set these? + key_data["alg"] = "RS256" + key_data["use"] = "sig" + key_data["kid"] = "keyid-abc123" + + jwk = PyJWK.from_dict(key_data) + + assert jwk.key_type == "RSA" + assert jwk.key_id == "keyid-abc123" + assert jwk.public_key_use == "sig" + + @pytest.mark.skipif( + not has_crypto, + reason="Scenario requires cryptography to not be installed", + ) + def test_should_load_key_from_jwk_data_json_string(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path("jwk_rsa_pub.json"), "r") as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + key_data_str = algo.to_jwk(pub_key) + key_data = json.loads(key_data_str) + + # TODO Should `to_jwk` set these? + key_data["alg"] = "RS256" + key_data["use"] = "sig" + key_data["kid"] = "keyid-abc123" + + jwk = PyJWK.from_json(json.dumps(key_data)) + + assert jwk.key_type == "RSA" + assert jwk.key_id == "keyid-abc123" + assert jwk.public_key_use == "sig" + + +class TestPyJWKSet: + @pytest.mark.skipif( + not has_crypto, + reason="Scenario requires cryptography to not be installed", + ) + def test_should_load_keys_from_jwk_data_dict(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path("jwk_rsa_pub.json"), "r") as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + key_data_str = algo.to_jwk(pub_key) + key_data = json.loads(key_data_str) + + # TODO Should `to_jwk` set these? + key_data["alg"] = "RS256" + key_data["use"] = "sig" + key_data["kid"] = "keyid-abc123" + + jwk_set = PyJWKSet.from_dict({"keys": [key_data]}) + jwk = jwk_set.keys[0] + + assert jwk.key_type == "RSA" + assert jwk.key_id == "keyid-abc123" + assert jwk.public_key_use == "sig" + + @pytest.mark.skipif( + not has_crypto, + reason="Scenario requires cryptography to not be installed", + ) + def test_should_load_keys_from_jwk_data_json_string(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path("jwk_rsa_pub.json"), "r") as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + key_data_str = algo.to_jwk(pub_key) + key_data = json.loads(key_data_str) + + # TODO Should `to_jwk` set these? + key_data["alg"] = "RS256" + key_data["use"] = "sig" + key_data["kid"] = "keyid-abc123" + + jwk_set = PyJWKSet.from_json(json.dumps({"keys": [key_data]})) + jwk = jwk_set.keys[0] + + assert jwk.key_type == "RSA" + assert jwk.key_id == "keyid-abc123" + assert jwk.public_key_use == "sig" diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py new file mode 100644 index 0000000..f4fb14a --- /dev/null +++ b/tests/test_jwks_client.py @@ -0,0 +1,113 @@ +import pytest +import requests_mock + +import jwt +from jwt import PyJWKClient +from jwt.api_jwk import PyJWK +from jwt.exceptions import PyJWKClientError + +from .test_algorithms import has_crypto + + +@pytest.fixture +def mocked_response(): + return { + "keys": [ + { + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "0wtlJRY9-ru61LmOgieeI7_rD1oIna9QpBMAOWw8wTuoIhFQFwcIi7MFB7IEfelCPj08vkfLsuFtR8cG07EE4uvJ78bAqRjMsCvprWp4e2p7hqPnWcpRpDEyHjzirEJle1LPpjLLVaSWgkbrVaOD0lkWkP1T1TkrOset_Obh8BwtO-Ww-UfrEwxTyz1646AGkbT2nL8PX0trXrmira8GnrCkFUgTUS61GoTdb9bCJ19PLX9Gnxw7J0BtR0GubopXq8KlI0ThVql6ZtVGN2dvmrCPAVAZleM5TVB61m0VSXvGWaF6_GeOhbFoyWcyUmFvzWhBm8Q38vWgsSI7oHTkEw", + "e": "AQAB", + "kid": "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw", + "x5t": "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw", + "x5c": [ + "MIIDBzCCAe+gAwIBAgIJNtD9Ozi6j2jJMA0GCSqGSIb3DQEBCwUAMCExHzAdBgNVBAMTFmRldi04N2V2eDlydS5hdXRoMC5jb20wHhcNMTkwNjIwMTU0NDU4WhcNMzMwMjI2MTU0NDU4WjAhMR8wHQYDVQQDExZkZXYtODdldng5cnUuYXV0aDAuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0wtlJRY9+ru61LmOgieeI7/rD1oIna9QpBMAOWw8wTuoIhFQFwcIi7MFB7IEfelCPj08vkfLsuFtR8cG07EE4uvJ78bAqRjMsCvprWp4e2p7hqPnWcpRpDEyHjzirEJle1LPpjLLVaSWgkbrVaOD0lkWkP1T1TkrOset/Obh8BwtO+Ww+UfrEwxTyz1646AGkbT2nL8PX0trXrmira8GnrCkFUgTUS61GoTdb9bCJ19PLX9Gnxw7J0BtR0GubopXq8KlI0ThVql6ZtVGN2dvmrCPAVAZleM5TVB61m0VSXvGWaF6/GeOhbFoyWcyUmFvzWhBm8Q38vWgsSI7oHTkEwIDAQABo0IwQDAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBQlGXpmYaXFB7Q3eG69Uhjd4cFp/jAOBgNVHQ8BAf8EBAMCAoQwDQYJKoZIhvcNAQELBQADggEBAIzQOF/h4T5WWAdjhcIwdNS7hS2Deq+UxxkRv+uavj6O9mHLuRG1q5onvSFShjECXaYT6OGibn7Ufw/JSm3+86ZouMYjBEqGh4OvWRkwARy1YTWUVDGpT2HAwtIq3lfYvhe8P4VfZByp1N4lfn6X2NcJflG+Q+mfXNmRFyyft3Oq51PCZyyAkU7bTun9FmMOyBtmJvQjZ8RXgBLvu9nUcZB8yTVoeUEg4cLczQlli/OkiFXhWgrhVr8uF0/9klslMFXtm78iYSgR8/oC+k1pSNd1+ESSt7n6+JiAQ2Co+ZNKta7LTDGAjGjNDymyoCrZpeuYQwwnHYEHu/0khjAxhXo=" + ], + } + ] + } + + +@pytest.mark.skipif( + not has_crypto, reason="Not supported without cryptography library" +) +class TestPyJWKClient: + def test_get_jwk_set(self, mocked_response): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with requests_mock.mock() as m: + m.get(url, json=mocked_response) + jwks_client = PyJWKClient(url) + jwk_set = jwks_client.get_jwk_set() + + assert len(jwk_set.keys) == 1 + + def test_get_signing_keys(self, mocked_response): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with requests_mock.mock() as m: + m.get(url, json=mocked_response) + jwks_client = PyJWKClient(url) + signing_keys = jwks_client.get_signing_keys() + + assert len(signing_keys) == 1 + assert isinstance(signing_keys[0], PyJWK) + + def test_get_signing_keys_raises_if_none_found(self, mocked_response): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with requests_mock.mock() as m: + mocked_key = mocked_response["keys"][0].copy() + mocked_key["use"] = "enc" + response = {"keys": [mocked_key]} + m.get(url, json=response) + jwks_client = PyJWKClient(url) + + with pytest.raises(PyJWKClientError) as exc: + jwks_client.get_signing_keys() + + assert "The JWKS endpoint did not contain any signing keys" in str( + exc.value + ) + + def test_get_signing_key(self, mocked_response): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + with requests_mock.mock() as m: + m.get(url, json=mocked_response) + jwks_client = PyJWKClient(url) + signing_key = jwks_client.get_signing_key(kid) + + assert isinstance(signing_key, PyJWK) + assert signing_key.key_type == "RSA" + assert signing_key.key_id == kid + assert signing_key.public_key_use == "sig" + + def test_get_signing_key_from_jwt(self, mocked_response): + token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with requests_mock.mock() as m: + m.get(url, json=mocked_response) + jwks_client = PyJWKClient(url) + signing_key = jwks_client.get_signing_key_from_jwt(token) + + data = jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + audience="https://expenses-api", + options={"verify_exp": False}, + ) + + assert data == { + "iss": "https://dev-87evx9ru.auth0.com/", + "sub": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC@clients", + "aud": "https://expenses-api", + "iat": 1572006954, + "exp": 1572006964, + "azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC", + "gty": "client-credentials", + } |