summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py39
1 files changed, 18 insertions, 21 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 739df80..46a1a53 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -9,6 +9,8 @@ from .utils import (
der_to_raw_signature,
force_bytes,
from_base64url_uint,
+ is_pem_format,
+ is_ssh_key,
raw_to_der_signature,
to_base64url_uint,
)
@@ -183,14 +185,7 @@ class HMACAlgorithm(Algorithm):
def prepare_key(self, key):
key = force_bytes(key)
- invalid_strings = [
- b"-----BEGIN PUBLIC KEY-----",
- b"-----BEGIN CERTIFICATE-----",
- b"-----BEGIN RSA PUBLIC KEY-----",
- b"ssh-rsa",
- ]
-
- if any(string_value in key for string_value in invalid_strings):
+ if is_pem_format(key) or is_ssh_key(key):
raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
@@ -551,26 +546,28 @@ if has_crypto:
pass
def prepare_key(self, key):
-
- if isinstance(
- key,
- (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
- ):
- return key
-
if isinstance(key, (bytes, str)):
if isinstance(key, str):
key = key.encode("utf-8")
str_key = key.decode("utf-8")
if "-----BEGIN PUBLIC" in str_key:
- return load_pem_public_key(key)
- if "-----BEGIN PRIVATE" in str_key:
- return load_pem_private_key(key, password=None)
- if str_key[0:4] == "ssh-":
- return load_ssh_public_key(key)
+ key = load_pem_public_key(key)
+ elif "-----BEGIN PRIVATE" in str_key:
+ key = load_pem_private_key(key, password=None)
+ elif str_key[0:4] == "ssh-":
+ key = load_ssh_public_key(key)
- raise TypeError("Expecting a PEM-formatted or OpenSSH key.")
+ # Explicit check the key to prevent confusing errors from cryptography
+ if not isinstance(
+ key,
+ (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
+ ):
+ raise InvalidKeyError(
+ "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
+ )
+
+ return key
def sign(self, msg, key):
"""