summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
new file mode 100644
index 0000000..89ea75b
--- /dev/null
+++ b/jwt/algorithms.py
@@ -0,0 +1,200 @@
+import hashlib
+import hmac
+
+from jwt import register_algorithm
+from jwt.compat import constant_time_compare, string_types, text_type
+
+try:
+ from cryptography.hazmat.primitives import interfaces, hashes
+ from cryptography.hazmat.primitives.serialization import (
+ load_pem_private_key, load_pem_public_key, load_ssh_public_key
+ )
+ from cryptography.hazmat.primitives.asymmetric import ec, padding
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.exceptions import InvalidSignature
+
+ has_crypto = True
+except ImportError:
+ has_crypto = False
+
+
+def _register_default_algorithms():
+ """ Registers the algorithms that are implemented by the library """
+ register_algorithm('none', NoneAlgorithm())
+ register_algorithm('HS256', HMACAlgorithm(hashlib.sha256))
+ register_algorithm('HS384', HMACAlgorithm(hashlib.sha384))
+ register_algorithm('HS512', HMACAlgorithm(hashlib.sha512))
+
+ if has_crypto:
+ register_algorithm('RS256', RSAAlgorithm(hashes.SHA256()))
+ register_algorithm('RS384', RSAAlgorithm(hashes.SHA384()))
+ register_algorithm('RS512', RSAAlgorithm(hashes.SHA512()))
+
+ register_algorithm('ES256', ECAlgorithm(hashes.SHA256()))
+ register_algorithm('ES384', ECAlgorithm(hashes.SHA384()))
+ register_algorithm('ES512', ECAlgorithm(hashes.SHA512()))
+
+
+class Algorithm(object):
+ """ The interface for an algorithm used to sign and verify JWTs """
+ def prepare_key(self, key):
+ """
+ Performs necessary validation and conversions on the key and returns
+ the key value in the proper format for sign() and verify()
+ """
+ raise NotImplementedError
+
+ def sign(self, msg, key):
+ """
+ Returns a digital signature for the specified message using the
+ specified key value
+ """
+ raise NotImplementedError
+
+ def verify(self, msg, key, sig):
+ """
+ Verifies that the specified digital signature is valid for the specified
+ message and key values.
+ """
+ raise NotImplementedError
+
+
+class NoneAlgorithm(Algorithm):
+ """
+ Placeholder for use when no signing or verification operations are required
+ """
+ def prepare_key(self, key):
+ return None
+
+ def sign(self, msg, key):
+ return b''
+
+ def verify(self, msg, key):
+ return True
+
+
+class HMACAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using HMAC and the specified
+ hash function
+ """
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if not isinstance(key, string_types) and not isinstance(key, bytes):
+ raise TypeError('Expecting a string- or bytes-formatted key.')
+
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ return key
+
+ def sign(self, msg, key):
+ return hmac.new(key, msg, self.hash_alg).digest()
+
+ def verify(self, msg, key, sig):
+ return constant_time_compare(sig, self.sign(msg, key))
+
+if has_crypto:
+
+ class RSAAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using RSASSA-PKCS-v1_5 and
+ the specified hash function
+ """
+
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.RSAPrivateKey) or \
+ isinstance(key, interfaces.RSAPublicKey):
+ return key
+
+ if isinstance(key, string_types):
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ try:
+ if key.startswith(b'ssh-rsa'):
+ key = load_ssh_public_key(key, backend=default_backend())
+ else:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ key = load_pem_public_key(key, backend=default_backend())
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(
+ sig,
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False
+
+ class ECAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using ECDSA and the
+ specified hash function
+ """
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.EllipticCurvePrivateKey) or \
+ isinstance(key, interfaces.EllipticCurvePublicKey):
+ return key
+
+ if isinstance(key, string_types):
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ # Attempt to load key. We don't know if it's
+ # a Signing Key or a Verifying Key, so we try
+ # the Verifying Key first.
+ try:
+ key = load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(ec.ECDSA(self.hash_alg))
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(sig, ec.ECDSA(self.hash_alg))
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False