summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py99
1 files changed, 57 insertions, 42 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 4fae441..c809f15 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,8 +1,10 @@
import hashlib
import hmac
import json
+from typing import Any, Dict, Union
from .exceptions import InvalidKeyError
+from .types import JWKDict
from .utils import (
base64url_decode,
base64url_encode,
@@ -20,10 +22,18 @@ try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.asymmetric import ec, padding
+ from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import (
+ ECDSA,
+ SECP256K1,
+ SECP256R1,
+ SECP384R1,
+ SECP521R1,
+ EllipticCurve,
EllipticCurvePrivateKey,
+ EllipticCurvePrivateNumbers,
EllipticCurvePublicKey,
+ EllipticCurvePublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey,
@@ -73,7 +83,7 @@ requires_cryptography = {
}
-def get_default_algorithms():
+def get_default_algorithms() -> Dict[str, "Algorithm"]:
"""
Returns the algorithms that are implemented by the library.
"""
@@ -130,25 +140,29 @@ class Algorithm:
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
- return digest.finalize()
+ return bytes(digest.finalize())
else:
- return hash_alg(bytestr).digest()
+ return bytes(hash_alg(bytestr).digest())
- def prepare_key(self, key):
+ # TODO: all key-related `Any`s in this class should optimally be made
+ # variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
+ # that may still be poorly supported.
+
+ def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError
- def verify(self, msg, key, sig):
+ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
@@ -156,14 +170,14 @@ class Algorithm:
raise NotImplementedError
@staticmethod
- def to_jwk(key_obj):
+ def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
@@ -202,7 +216,7 @@ class HMACAlgorithm(Algorithm):
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -242,7 +256,7 @@ class HMACAlgorithm(Algorithm):
return base64url_decode(obj["k"])
- def sign(self, msg, key):
+ def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg, key, sig):
@@ -261,7 +275,7 @@ if has_crypto:
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -271,16 +285,15 @@ if has_crypto:
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
- key = force_bytes(key)
+ key_bytes = force_bytes(key)
try:
- if key.startswith(b"ssh-rsa"):
- key = load_ssh_public_key(key)
+ if key_bytes.startswith(b"ssh-rsa"):
+ return load_ssh_public_key(key_bytes)
else:
- key = load_pem_private_key(key, password=None)
+ return load_pem_private_key(key_bytes, password=None)
except ValueError:
- key = load_pem_public_key(key)
- return key
+ return load_pem_public_key(key_bytes)
@staticmethod
def to_jwk(key_obj):
@@ -383,12 +396,10 @@ if has_crypto:
return numbers.private_key()
elif "n" in obj and "e" in obj:
# Public key
- numbers = RSAPublicNumbers(
+ return RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
- )
-
- return numbers.public_key()
+ ).public_key()
else:
raise InvalidKeyError("Not a public or private key")
@@ -412,7 +423,7 @@ if has_crypto:
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
- def __init__(self, hash_alg):
+ def __init__(self, hash_alg) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
@@ -422,18 +433,18 @@ if has_crypto:
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
- key = force_bytes(key)
+ key_bytes = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
- if key.startswith(b"ecdsa-sha2-"):
- key = load_ssh_public_key(key)
+ if key_bytes.startswith(b"ecdsa-sha2-"):
+ key = load_ssh_public_key(key_bytes)
else:
- key = load_pem_public_key(key)
+ key = load_pem_public_key(key_bytes)
except ValueError:
- key = load_pem_private_key(key, password=None)
+ key = load_pem_private_key(key_bytes, password=None)
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
@@ -444,7 +455,7 @@ if has_crypto:
return key
def sign(self, msg, key):
- der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
+ der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
@@ -457,7 +468,7 @@ if has_crypto:
try:
if isinstance(key, EllipticCurvePrivateKey):
key = key.public_key()
- key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
+ key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@@ -472,13 +483,13 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
- if isinstance(key_obj.curve, ec.SECP256R1):
+ if isinstance(key_obj.curve, SECP256R1):
crv = "P-256"
- elif isinstance(key_obj.curve, ec.SECP384R1):
+ elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384"
- elif isinstance(key_obj.curve, ec.SECP521R1):
+ elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521"
- elif isinstance(key_obj.curve, ec.SECP256K1):
+ elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
@@ -498,7 +509,9 @@ if has_crypto:
return json.dumps(obj)
@staticmethod
- def from_jwk(jwk):
+ def from_jwk(
+ jwk: Any,
+ ) -> Union[EllipticCurvePublicKey, EllipticCurvePrivateKey]:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@@ -519,24 +532,26 @@ if has_crypto:
y = base64url_decode(obj.get("y"))
curve = obj.get("crv")
+ curve_obj: EllipticCurve
+
if curve == "P-256":
if len(x) == len(y) == 32:
- curve_obj = ec.SECP256R1()
+ curve_obj = SECP256R1()
else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
- curve_obj = ec.SECP384R1()
+ curve_obj = SECP384R1()
else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
- curve_obj = ec.SECP521R1()
+ curve_obj = SECP521R1()
else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1":
if len(x) == len(y) == 32:
- curve_obj = ec.SECP256K1()
+ curve_obj = SECP256K1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1"
@@ -544,7 +559,7 @@ if has_crypto:
else:
raise InvalidKeyError(f"Invalid curve: {curve}")
- public_numbers = ec.EllipticCurvePublicNumbers(
+ public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve_obj,
@@ -559,7 +574,7 @@ if has_crypto:
"D should be {} bytes for curve {}", len(x), curve
)
- return ec.EllipticCurvePrivateNumbers(
+ return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers
).private_key()
@@ -600,7 +615,7 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs) -> None:
pass
def prepare_key(self, key):