summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAjitomi, Daisuke <ajitomi@gmail.com>2021-03-22 17:14:46 +0900
committerGitHub <noreply@github.com>2021-03-22 14:14:46 +0600
commit81d5829183785b875253162687e59f3ff81674c3 (patch)
tree840b37c6fce82bef564c783417363d05d031e417
parentf6cd4ac12f039cc63f812adbd78b6224a834cc14 (diff)
downloadpyjwt-81d5829183785b875253162687e59f3ff81674c3.tar.gz
Support JWK without alg. (#624)
* Support JWK without alg. * Make kty mandatory on PyJWK. * Add tests for kty=OKP. * Add tests for OKP-type JWK. * Add support for ES256K.
-rw-r--r--jwt/api_jwk.py33
-rw-r--r--tests/test_api_jwk.py162
2 files changed, 192 insertions, 3 deletions
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index 3c10e07..a0f6364 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -1,7 +1,7 @@
import json
from .algorithms import get_default_algorithms
-from .exceptions import PyJWKError, PyJWKSetError
+from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
class PyJWK:
@@ -9,11 +9,40 @@ class PyJWK:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
+ kty = self._jwk_data.get("kty", None)
+ if not kty:
+ raise InvalidKeyError("kty is not found: %s" % self._jwk_data)
+
if not algorithm and isinstance(self._jwk_data, dict):
algorithm = self._jwk_data.get("alg", None)
if not algorithm:
- raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)
+ # Determine alg with kty (and crv).
+ crv = self._jwk_data.get("crv", None)
+ if kty == "EC":
+ if crv == "P-256" or not crv:
+ algorithm = "ES256"
+ elif crv == "P-384":
+ algorithm = "ES384"
+ elif crv == "P-521":
+ algorithm = "ES512"
+ elif crv == "secp256k1":
+ algorithm = "ES256K"
+ else:
+ raise InvalidKeyError("Unsupported crv: %s" % crv)
+ elif kty == "RSA":
+ algorithm = "RS256"
+ elif kty == "oct":
+ algorithm = "HS256"
+ elif kty == "OKP":
+ if not crv:
+ raise InvalidKeyError("crv is not found: %s" % self._jwk_data)
+ if crv == "Ed25519":
+ algorithm = "EdDSA"
+ else:
+ raise InvalidKeyError("Unsupported crv: %s" % crv)
+ else:
+ raise InvalidKeyError("Unsupported kty: %s" % kty)
self.Algorithm = self._algorithms.get(algorithm)
diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py
index 102af87..e0787f4 100644
--- a/tests/test_api_jwk.py
+++ b/tests/test_api_jwk.py
@@ -1,12 +1,20 @@
import json
+import pytest
+
from jwt.algorithms import has_crypto
from jwt.api_jwk import PyJWK, PyJWKSet
+from jwt.exceptions import InvalidKeyError, PyJWKError
from .utils import crypto_required, key_path
if has_crypto:
- from jwt.algorithms import RSAAlgorithm
+ from jwt.algorithms import (
+ ECAlgorithm,
+ Ed25519Algorithm,
+ HMACAlgorithm,
+ RSAAlgorithm,
+ )
class TestPyJWK:
@@ -52,6 +60,158 @@ class TestPyJWK:
assert jwk.key_id == "keyid-abc123"
assert jwk.public_key_use == "sig"
+ @crypto_required
+ def test_should_load_key_without_alg_from_dict(self):
+
+ with open(key_path("jwk_rsa_pub.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "RSA"
+ assert isinstance(jwk.Algorithm, RSAAlgorithm)
+ assert jwk.Algorithm.hash_alg == RSAAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_from_dict_with_algorithm(self):
+
+ with open(key_path("jwk_rsa_pub.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data, algorithm="RS256")
+
+ assert jwk.key_type == "RSA"
+ assert isinstance(jwk.Algorithm, RSAAlgorithm)
+ assert jwk.Algorithm.hash_alg == RSAAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_ec_p256_from_dict(self):
+
+ with open(key_path("jwk_ec_pub_P-256.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "EC"
+ assert isinstance(jwk.Algorithm, ECAlgorithm)
+ assert jwk.Algorithm.hash_alg == ECAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_ec_p384_from_dict(self):
+
+ with open(key_path("jwk_ec_pub_P-384.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "EC"
+ assert isinstance(jwk.Algorithm, ECAlgorithm)
+ assert jwk.Algorithm.hash_alg == ECAlgorithm.SHA384
+
+ @crypto_required
+ def test_should_load_key_ec_p521_from_dict(self):
+
+ with open(key_path("jwk_ec_pub_P-521.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "EC"
+ assert isinstance(jwk.Algorithm, ECAlgorithm)
+ assert jwk.Algorithm.hash_alg == ECAlgorithm.SHA512
+
+ @crypto_required
+ def test_should_load_key_ec_secp256k1_from_dict(self):
+
+ with open(key_path("jwk_ec_pub_secp256k1.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "EC"
+ assert isinstance(jwk.Algorithm, ECAlgorithm)
+ assert jwk.Algorithm.hash_alg == ECAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_hmac_from_dict(self):
+
+ with open(key_path("jwk_hmac.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "oct"
+ assert isinstance(jwk.Algorithm, HMACAlgorithm)
+ assert jwk.Algorithm.hash_alg == HMACAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_hmac_without_alg_from_dict(self):
+
+ with open(key_path("jwk_hmac.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ del key_data["alg"]
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "oct"
+ assert isinstance(jwk.Algorithm, HMACAlgorithm)
+ assert jwk.Algorithm.hash_alg == HMACAlgorithm.SHA256
+
+ @crypto_required
+ def test_should_load_key_okp_without_alg_from_dict(self):
+
+ with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile:
+ key_data = json.loads(keyfile.read())
+
+ jwk = PyJWK.from_dict(key_data)
+
+ assert jwk.key_type == "OKP"
+ assert isinstance(jwk.Algorithm, Ed25519Algorithm)
+
+ @crypto_required
+ def test_from_dict_should_throw_exception_if_arg_is_invalid(self):
+
+ with open(key_path("jwk_rsa_pub.json")) as keyfile:
+ valid_rsa_pub = json.loads(keyfile.read())
+ with open(key_path("jwk_ec_pub_P-256.json")) as keyfile:
+ valid_ec_pub = json.loads(keyfile.read())
+ with open(key_path("jwk_okp_pub_Ed25519.json")) as keyfile:
+ valid_okp_pub = json.loads(keyfile.read())
+
+ # Unknown algorithm
+ with pytest.raises(PyJWKError):
+ PyJWK.from_dict(valid_rsa_pub, algorithm="unknown")
+
+ # Missing kty
+ v = valid_rsa_pub.copy()
+ del v["kty"]
+ with pytest.raises(InvalidKeyError):
+ PyJWK.from_dict(v)
+
+ # Unknown kty
+ v = valid_rsa_pub.copy()
+ v["kty"] = "unknown"
+ with pytest.raises(InvalidKeyError):
+ PyJWK.from_dict(v)
+
+ # Unknown EC crv
+ v = valid_ec_pub.copy()
+ v["crv"] = "unknown"
+ with pytest.raises(InvalidKeyError):
+ PyJWK.from_dict(v)
+
+ # Unknown OKP crv
+ v = valid_okp_pub.copy()
+ v["crv"] = "unknown"
+ with pytest.raises(InvalidKeyError):
+ PyJWK.from_dict(v)
+
+ # Missing OKP crv
+ v = valid_okp_pub.copy()
+ del v["crv"]
+ with pytest.raises(InvalidKeyError):
+ PyJWK.from_dict(v)
+
class TestPyJWKSet:
@crypto_required