summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosé Padilla <jpadilla@webapplicate.com>2015-01-18 13:28:18 -0400
committerJosé Padilla <jpadilla@webapplicate.com>2015-01-18 13:28:18 -0400
commit6290a1c8b5928d691f9fc91b733550b93b5f8cc7 (patch)
treeba7a79fe402317c00ace2eadeb5901ea463cb809
parent0afba10cf16834e154a59280de089c30de3d9a61 (diff)
parenteff1505b5df81540eb6c3c73b1d8c924fd1c2ca5 (diff)
downloadpyjwt-6290a1c8b5928d691f9fc91b733550b93b5f8cc7.tar.gz
Merge pull request #71 from mark-adams/algorithm-registry
Fixes #70. Implemented registry pattern for class-based algorithms
-rw-r--r--jwt/__init__.py225
-rw-r--r--jwt/algorithms.py200
-rw-r--r--jwt/utils.py14
-rwxr-xr-xsetup.py3
-rw-r--r--tests/test_jwt.py72
5 files changed, 265 insertions, 249 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py
index 3a70913..b9a9986 100644
--- a/jwt/__init__.py
+++ b/jwt/__init__.py
@@ -5,16 +5,15 @@ Minimum implementation based on this spec:
http://self-issued.info/docs/draft-jones-json-web-token-01.html
"""
-import base64
import binascii
-import hashlib
-import hmac
-from datetime import datetime, timedelta
+
from calendar import timegm
from collections import Mapping
+from datetime import datetime, timedelta
-from .compat import (json, string_types, text_type, constant_time_compare,
- timedelta_total_seconds)
+from jwt.utils import base64url_decode, base64url_encode
+
+from .compat import (json, string_types, text_type, timedelta_total_seconds)
__version__ = '0.4.1'
@@ -22,6 +21,7 @@ __all__ = [
# Functions
'encode',
'decode',
+ 'register_algorithm',
# Exceptions
'InvalidTokenError',
@@ -33,9 +33,25 @@ __all__ = [
# Deprecated aliases
'ExpiredSignature',
'InvalidAudience',
- 'InvalidIssuer',
+ 'InvalidIssuer'
]
+_algorithms = {}
+
+
+def register_algorithm(alg_id, alg_obj):
+ """ Registers a new Algorithm for use when creating and verifying JWTs """
+ if alg_id in _algorithms:
+ raise ValueError('Algorithm already has a handler.')
+
+ if not isinstance(alg_obj, Algorithm):
+ raise TypeError('Object is not of type `Algorithm`')
+
+ _algorithms[alg_id] = alg_obj
+
+from jwt.algorithms import Algorithm, _register_default_algorithms # NOQA
+_register_default_algorithms()
+
class InvalidTokenError(Exception):
pass
@@ -56,187 +72,11 @@ class InvalidAudienceError(InvalidTokenError):
class InvalidIssuerError(InvalidTokenError):
pass
-
# Compatibility aliases (deprecated)
ExpiredSignature = ExpiredSignatureError
InvalidAudience = InvalidAudienceError
InvalidIssuer = InvalidIssuerError
-signing_methods = {
- 'none': lambda msg, key: b'',
- 'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(),
- 'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(),
- 'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest()
-}
-
-verify_methods = {
- 'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(),
- 'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(),
- 'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest()
-}
-
-
-def prepare_HS_key(key):
- if not isinstance(key, string_types) and not isinstance(key, bytes):
- raise TypeError('Expecting a string- or bytes-formatted key.')
-
- if isinstance(key, text_type):
- key = key.encode('utf-8')
-
- return key
-
-prepare_key_methods = {
- 'none': lambda key: None,
- 'HS256': prepare_HS_key,
- 'HS384': prepare_HS_key,
- 'HS512': prepare_HS_key
-}
-
-try:
- from cryptography.hazmat.primitives import interfaces, hashes
- from cryptography.hazmat.primitives.serialization import (
- load_pem_private_key, load_pem_public_key, load_ssh_public_key
- )
- from cryptography.hazmat.primitives.asymmetric import ec, padding
- from cryptography.hazmat.backends import default_backend
- from cryptography.exceptions import InvalidSignature
-
- def sign_rsa(msg, key, hashalg):
- signer = key.signer(
- padding.PKCS1v15(),
- hashalg
- )
-
- signer.update(msg)
- return signer.finalize()
-
- def verify_rsa(msg, key, hashalg, sig):
- verifier = key.verifier(
- sig,
- padding.PKCS1v15(),
- hashalg
- )
-
- verifier.update(msg)
-
- try:
- verifier.verify()
- return True
- except InvalidSignature:
- return False
-
- signing_methods.update({
- 'RS256': lambda msg, key: sign_rsa(msg, key, hashes.SHA256()),
- 'RS384': lambda msg, key: sign_rsa(msg, key, hashes.SHA384()),
- 'RS512': lambda msg, key: sign_rsa(msg, key, hashes.SHA512())
- })
-
- verify_methods.update({
- 'RS256': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA256(), sig),
- 'RS384': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA384(), sig),
- 'RS512': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA512(), sig)
- })
-
- def prepare_RS_key(key):
- if isinstance(key, interfaces.RSAPrivateKey) or \
- isinstance(key, interfaces.RSAPublicKey):
- return key
-
- if isinstance(key, string_types):
- if isinstance(key, text_type):
- key = key.encode('utf-8')
-
- try:
- if key.startswith(b'ssh-rsa'):
- key = load_ssh_public_key(key, backend=default_backend())
- else:
- key = load_pem_private_key(key, password=None, backend=default_backend())
- except ValueError:
- key = load_pem_public_key(key, backend=default_backend())
- else:
- raise TypeError('Expecting a PEM-formatted key.')
-
- return key
-
- prepare_key_methods.update({
- 'RS256': prepare_RS_key,
- 'RS384': prepare_RS_key,
- 'RS512': prepare_RS_key
- })
-
- def sign_ecdsa(msg, key, hashalg):
- signer = key.signer(ec.ECDSA(hashalg))
-
- signer.update(msg)
- return signer.finalize()
-
- def verify_ecdsa(msg, key, hashalg, sig):
- verifier = key.verifier(sig, ec.ECDSA(hashalg))
-
- verifier.update(msg)
-
- try:
- verifier.verify()
- return True
- except InvalidSignature:
- return False
-
- signing_methods.update({
- 'ES256': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA256()),
- 'ES384': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA384()),
- 'ES512': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA512()),
- })
-
- verify_methods.update({
- 'ES256': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA256(), sig),
- 'ES384': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA384(), sig),
- 'ES512': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA512(), sig),
- })
-
- def prepare_ES_key(key):
- if isinstance(key, interfaces.EllipticCurvePrivateKey) or \
- isinstance(key, interfaces.EllipticCurvePublicKey):
- return key
-
- if isinstance(key, string_types):
- if isinstance(key, text_type):
- key = key.encode('utf-8')
-
- # Attempt to load key. We don't know if it's
- # a Signing Key or a Verifying Key, so we try
- # the Verifying Key first.
- try:
- key = load_pem_public_key(key, backend=default_backend())
- except ValueError:
- key = load_pem_private_key(key, password=None, backend=default_backend())
-
- else:
- raise TypeError('Expecting a PEM-formatted key.')
-
- return key
-
- prepare_key_methods.update({
- 'ES256': prepare_ES_key,
- 'ES384': prepare_ES_key,
- 'ES512': prepare_ES_key
- })
-
-except ImportError:
- pass
-
-
-def base64url_decode(input):
- rem = len(input) % 4
-
- if rem > 0:
- input += b'=' * (4 - rem)
-
- return base64.urlsafe_b64decode(input)
-
-
-def base64url_encode(input):
- return base64.urlsafe_b64encode(input).replace(b'=', b'')
-
def header(jwt):
if isinstance(jwt, text_type):
@@ -290,8 +130,10 @@ def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
# Segments
signing_input = b'.'.join(segments)
try:
- key = prepare_key_methods[algorithm](key)
- signature = signing_methods[algorithm](signing_input, key)
+ alg_obj = _algorithms[algorithm]
+ key = alg_obj.prepare_key(key)
+ signature = alg_obj.sign(signing_input, key)
+
except KeyError:
raise NotImplementedError('Algorithm not supported')
@@ -360,17 +202,12 @@ def verify_signature(payload, signing_input, header, signature, key='',
raise TypeError('audience must be a string or None')
try:
- algorithm = header['alg'].upper()
- key = prepare_key_methods[algorithm](key)
+ alg_obj = _algorithms[header['alg'].upper()]
+ key = alg_obj.prepare_key(key)
- if algorithm.startswith('HS'):
- expected = verify_methods[algorithm](signing_input, key)
+ if not alg_obj.verify(signing_input, key, signature):
+ raise DecodeError('Signature verification failed')
- if not constant_time_compare(signature, expected):
- raise DecodeError('Signature verification failed')
- else:
- if not verify_methods[algorithm](signing_input, key, signature):
- raise DecodeError('Signature verification failed')
except KeyError:
raise DecodeError('Algorithm not supported')
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
new file mode 100644
index 0000000..89ea75b
--- /dev/null
+++ b/jwt/algorithms.py
@@ -0,0 +1,200 @@
+import hashlib
+import hmac
+
+from jwt import register_algorithm
+from jwt.compat import constant_time_compare, string_types, text_type
+
+try:
+ from cryptography.hazmat.primitives import interfaces, hashes
+ from cryptography.hazmat.primitives.serialization import (
+ load_pem_private_key, load_pem_public_key, load_ssh_public_key
+ )
+ from cryptography.hazmat.primitives.asymmetric import ec, padding
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.exceptions import InvalidSignature
+
+ has_crypto = True
+except ImportError:
+ has_crypto = False
+
+
+def _register_default_algorithms():
+ """ Registers the algorithms that are implemented by the library """
+ register_algorithm('none', NoneAlgorithm())
+ register_algorithm('HS256', HMACAlgorithm(hashlib.sha256))
+ register_algorithm('HS384', HMACAlgorithm(hashlib.sha384))
+ register_algorithm('HS512', HMACAlgorithm(hashlib.sha512))
+
+ if has_crypto:
+ register_algorithm('RS256', RSAAlgorithm(hashes.SHA256()))
+ register_algorithm('RS384', RSAAlgorithm(hashes.SHA384()))
+ register_algorithm('RS512', RSAAlgorithm(hashes.SHA512()))
+
+ register_algorithm('ES256', ECAlgorithm(hashes.SHA256()))
+ register_algorithm('ES384', ECAlgorithm(hashes.SHA384()))
+ register_algorithm('ES512', ECAlgorithm(hashes.SHA512()))
+
+
+class Algorithm(object):
+ """ The interface for an algorithm used to sign and verify JWTs """
+ def prepare_key(self, key):
+ """
+ Performs necessary validation and conversions on the key and returns
+ the key value in the proper format for sign() and verify()
+ """
+ raise NotImplementedError
+
+ def sign(self, msg, key):
+ """
+ Returns a digital signature for the specified message using the
+ specified key value
+ """
+ raise NotImplementedError
+
+ def verify(self, msg, key, sig):
+ """
+ Verifies that the specified digital signature is valid for the specified
+ message and key values.
+ """
+ raise NotImplementedError
+
+
+class NoneAlgorithm(Algorithm):
+ """
+ Placeholder for use when no signing or verification operations are required
+ """
+ def prepare_key(self, key):
+ return None
+
+ def sign(self, msg, key):
+ return b''
+
+ def verify(self, msg, key):
+ return True
+
+
+class HMACAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using HMAC and the specified
+ hash function
+ """
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if not isinstance(key, string_types) and not isinstance(key, bytes):
+ raise TypeError('Expecting a string- or bytes-formatted key.')
+
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ return key
+
+ def sign(self, msg, key):
+ return hmac.new(key, msg, self.hash_alg).digest()
+
+ def verify(self, msg, key, sig):
+ return constant_time_compare(sig, self.sign(msg, key))
+
+if has_crypto:
+
+ class RSAAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using RSASSA-PKCS-v1_5 and
+ the specified hash function
+ """
+
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.RSAPrivateKey) or \
+ isinstance(key, interfaces.RSAPublicKey):
+ return key
+
+ if isinstance(key, string_types):
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ try:
+ if key.startswith(b'ssh-rsa'):
+ key = load_ssh_public_key(key, backend=default_backend())
+ else:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+ except ValueError:
+ key = load_pem_public_key(key, backend=default_backend())
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(
+ sig,
+ padding.PKCS1v15(),
+ self.hash_alg
+ )
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False
+
+ class ECAlgorithm(Algorithm):
+ """
+ Performs signing and verification operations using ECDSA and the
+ specified hash function
+ """
+ def __init__(self, hash_alg):
+ self.hash_alg = hash_alg
+
+ def prepare_key(self, key):
+ if isinstance(key, interfaces.EllipticCurvePrivateKey) or \
+ isinstance(key, interfaces.EllipticCurvePublicKey):
+ return key
+
+ if isinstance(key, string_types):
+ if isinstance(key, text_type):
+ key = key.encode('utf-8')
+
+ # Attempt to load key. We don't know if it's
+ # a Signing Key or a Verifying Key, so we try
+ # the Verifying Key first.
+ try:
+ key = load_pem_public_key(key, backend=default_backend())
+ except ValueError:
+ key = load_pem_private_key(key, password=None, backend=default_backend())
+
+ else:
+ raise TypeError('Expecting a PEM-formatted key.')
+
+ return key
+
+ def sign(self, msg, key):
+ signer = key.signer(ec.ECDSA(self.hash_alg))
+
+ signer.update(msg)
+ return signer.finalize()
+
+ def verify(self, msg, key, sig):
+ verifier = key.verifier(sig, ec.ECDSA(self.hash_alg))
+
+ verifier.update(msg)
+
+ try:
+ verifier.verify()
+ return True
+ except InvalidSignature:
+ return False
diff --git a/jwt/utils.py b/jwt/utils.py
new file mode 100644
index 0000000..e6c1ef3
--- /dev/null
+++ b/jwt/utils.py
@@ -0,0 +1,14 @@
+import base64
+
+
+def base64url_decode(input):
+ rem = len(input) % 4
+
+ if rem > 0:
+ input += b'=' * (4 - rem)
+
+ return base64.urlsafe_b64decode(input)
+
+
+def base64url_encode(input):
+ return base64.urlsafe_b64encode(input).replace(b'=', b'')
diff --git a/setup.py b/setup.py
index 62d5df7..e703db6 100755
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
-import sys
import re
+import sys
+
from setuptools import setup
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index a57ab31..bd9ca06 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -45,6 +45,10 @@ class TestJWT(unittest.TestCase):
self.payload = {'iss': 'jeff', 'exp': utc_timestamp() + 15,
'claim': 'insanity'}
+ def test_register_algorithm_rejects_non_algorithm_obj(self):
+ with self.assertRaises(TypeError):
+ jwt.register_algorithm('AAA123', {})
+
def test_encode_decode(self):
secret = 'secret'
jwt_message = jwt.encode(self.payload, secret)
@@ -549,35 +553,15 @@ class TestJWT(unittest.TestCase):
load_output = jwt.load(jwt_message)
jwt.verify_signature(key=pub_rsakey, *load_output)
- def test_rsa_related_signing_methods(self):
- if has_crypto:
- self.assertTrue('RS256' in jwt.signing_methods)
- self.assertTrue('RS384' in jwt.signing_methods)
- self.assertTrue('RS512' in jwt.signing_methods)
- else:
- self.assertFalse('RS256' in jwt.signing_methods)
- self.assertFalse('RS384' in jwt.signing_methods)
- self.assertFalse('RS512' in jwt.signing_methods)
-
- def test_rsa_related_verify_methods(self):
- if has_crypto:
- self.assertTrue('RS256' in jwt.verify_methods)
- self.assertTrue('RS384' in jwt.verify_methods)
- self.assertTrue('RS512' in jwt.verify_methods)
- else:
- self.assertFalse('RS256' in jwt.verify_methods)
- self.assertFalse('RS384' in jwt.verify_methods)
- self.assertFalse('RS512' in jwt.verify_methods)
-
- def test_rsa_related_key_preparation_methods(self):
+ def test_rsa_related_algorithms(self):
if has_crypto:
- self.assertTrue('RS256' in jwt.prepare_key_methods)
- self.assertTrue('RS384' in jwt.prepare_key_methods)
- self.assertTrue('RS512' in jwt.prepare_key_methods)
+ self.assertTrue('RS256' in jwt._algorithms)
+ self.assertTrue('RS384' in jwt._algorithms)
+ self.assertTrue('RS512' in jwt._algorithms)
else:
- self.assertFalse('RS256' in jwt.prepare_key_methods)
- self.assertFalse('RS384' in jwt.prepare_key_methods)
- self.assertFalse('RS512' in jwt.prepare_key_methods)
+ self.assertFalse('RS256' in jwt._algorithms)
+ self.assertFalse('RS384' in jwt._algorithms)
+ self.assertFalse('RS512' in jwt._algorithms)
@unittest.skipIf(not has_crypto, "Can't run without cryptography library")
def test_encode_decode_with_ecdsa_sha256(self):
@@ -669,35 +653,15 @@ class TestJWT(unittest.TestCase):
load_output = jwt.load(jwt_message)
jwt.verify_signature(key=pub_eckey, *load_output)
- def test_ecdsa_related_signing_methods(self):
- if has_crypto:
- self.assertTrue('ES256' in jwt.signing_methods)
- self.assertTrue('ES384' in jwt.signing_methods)
- self.assertTrue('ES512' in jwt.signing_methods)
- else:
- self.assertFalse('ES256' in jwt.signing_methods)
- self.assertFalse('ES384' in jwt.signing_methods)
- self.assertFalse('ES512' in jwt.signing_methods)
-
- def test_ecdsa_related_verify_methods(self):
- if has_crypto:
- self.assertTrue('ES256' in jwt.verify_methods)
- self.assertTrue('ES384' in jwt.verify_methods)
- self.assertTrue('ES512' in jwt.verify_methods)
- else:
- self.assertFalse('ES256' in jwt.verify_methods)
- self.assertFalse('ES384' in jwt.verify_methods)
- self.assertFalse('ES512' in jwt.verify_methods)
-
- def test_ecdsa_related_key_preparation_methods(self):
+ def test_ecdsa_related_algorithms(self):
if has_crypto:
- self.assertTrue('ES256' in jwt.prepare_key_methods)
- self.assertTrue('ES384' in jwt.prepare_key_methods)
- self.assertTrue('ES512' in jwt.prepare_key_methods)
+ self.assertTrue('ES256' in jwt._algorithms)
+ self.assertTrue('ES384' in jwt._algorithms)
+ self.assertTrue('ES512' in jwt._algorithms)
else:
- self.assertFalse('ES256' in jwt.prepare_key_methods)
- self.assertFalse('ES384' in jwt.prepare_key_methods)
- self.assertFalse('ES512' in jwt.prepare_key_methods)
+ self.assertFalse('ES256' in jwt._algorithms)
+ self.assertFalse('ES384' in jwt._algorithms)
+ self.assertFalse('ES512' in jwt._algorithms)
def test_check_audience(self):
payload = {