summaryrefslogtreecommitdiff
path: root/jwt/algorithms.py
diff options
context:
space:
mode:
authorMark Adams <mark@markadams.me>2015-01-06 08:00:28 -0600
committerMark Adams <mark@markadams.me>2015-01-18 10:26:25 -0600
commit1da2d4a52d55d64f77d7c6d6b52cdba555dc0e0b (patch)
tree4364dabb3fe0a6f3608e67f901e0dc7239b2f967 /jwt/algorithms.py
parent0afba10cf16834e154a59280de089c30de3d9a61 (diff)
downloadpyjwt-1da2d4a52d55d64f77d7c6d6b52cdba555dc0e0b.tar.gz
Fixes #70. Refactored all HMAC, RSA, and EC code into seperate classes in the algorithms module
Added register_algorithm to add new algorithms.
Diffstat (limited to 'jwt/algorithms.py')
-rw-r--r--jwt/algorithms.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
new file mode 100644
index 0000000..2f6f113
--- /dev/null
+++ b/jwt/algorithms.py
@@ -0,0 +1,200 @@
+import hashlib
+import hmac
+import sys
+
+from jwt import register_algorithm
+
+if sys.version_info >= (3, 0, 0):
+ unicode = str
+ basestring = str
+
+try:
+ from cryptography.hazmat.primitives import interfaces, hashes
+ from cryptography.hazmat.primitives.serialization import (
+ load_pem_private_key, load_pem_public_key, load_ssh_public_key
+ )
+ from cryptography.hazmat.primitives.asymmetric import ec, padding
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.exceptions import InvalidSignature
+
+ has_crypto = True
+except ImportError:
+ has_crypto = False
+
+
+def _register_default_algorithms():
+ register_algorithm('none', NoneAlgorithm())
+ register_algorithm('HS256', HMACAlgorithm(hashlib.sha256))
+ register_algorithm('HS384', HMACAlgorithm(hashlib.sha384))
+ register_algorithm('HS512', HMACAlgorithm(hashlib.sha512))
+
+ if has_crypto:
+ register_algorithm('RS256', RSAAlgorithm(hashes.SHA256()))
+ register_algorithm('RS384', RSAAlgorithm(hashes.SHA384()))
+ register_algorithm('RS512', RSAAlgorithm(hashes.SHA512()))
+
+ register_algorithm('ES256', ECAlgorithm(hashes.SHA256()))
+ register_algorithm('ES384', ECAlgorithm(hashes.SHA384()))
+ register_algorithm('ES512', ECAlgorithm(hashes.SHA512()))
+
+
+class Algorithm(object):
+ def prepare_key(self, key):
+ pass
+
+ def sign(self, msg, key):
+ pass
+
+ def verify(self, msg, key, sig):
+ pass
+
+
+class NoneAlgorithm(Algorithm):
+ def prepare_key(self, key):
+ return None
+
+ def sign(self, msg, key):
+ return b''
+
+ def verify(self, msg, key):
+ return True
+
+
+class HMACAlgorithm(Algorithm):
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if not isinstance(key, basestring) and not isinstance(key, bytes):
+ raise TypeError('Expecting a string- or bytes-formatted key.')
+
+ if isinstance(key, unicode):
+ key = key.encode('utf-8')
+
+ return key
+
+ def sign(self, msg, key):
+ return hmac.new(key, msg, self.hash_alg).digest()
+
+ def verify(self, msg, key, sig):
+ return self._constant_time_compare(sig, self.sign(msg, key))
+
+ try:
+ _constant_time_compare = staticmethod(hmac.compare_digest)
+ except AttributeError:
+ # Fallback for Python < 2.7.7 and Python < 3.3
+ @staticmethod
+ def constant_time_compare(val1, val2):
+ """
+ Returns True if the two strings are equal, False otherwise.
+
+ The time taken is independent of the number of characters that match.
+ """
+ if len(val1) != len(val2):
+ return False
+
+ result = 0
+
+ if sys.version_info >= (3, 0, 0):
+ # Bytes are numbers
+ for x, y in zip(val1, val2):
+ result |= x ^ y
+ else:
+ for x, y in zip(val1, val2):
+ result |= ord(x) ^ ord(y)
+
+ return result == 0
+
+if has_crypto:
+
+ class RSAAlgorithm(Algorithm):
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.RSAPrivateKey) or \
+ isinstance(key, interfaces.RSAPublicKey):
+ return key
+
+ if isinstance(key, basestring):
+ if isinstance(key, unicode):
+ key = key.encode('utf-8')
+
+ try:
+ if key.startswith(b'ssh-rsa'):
+ key = load_ssh_public_key(key, backend=default_backend())
+ else:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ key = load_pem_public_key(key, backend=default_backend())
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(
+ sig,
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False
+
+ class ECAlgorithm(Algorithm):
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.EllipticCurvePrivateKey) or \
+ isinstance(key, interfaces.EllipticCurvePublicKey):
+ return key
+
+ if isinstance(key, basestring):
+ if isinstance(key, unicode):
+ key = key.encode('utf-8')
+
+ # Attempt to load key. We don't know if it's
+ # a Signing Key or a Verifying Key, so we try
+ # the Verifying Key first.
+ try:
+ key = load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(ec.ECDSA(self.hash_alg))
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(sig, ec.ECDSA(self.hash_alg))
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False