summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Dufresne <jon.dufresne@gmail.com>2020-06-18 22:53:17 -0700
committerGitHub <noreply@github.com>2020-06-19 11:53:17 +0600
commitdc8dc7d05d54bd5502295601c01b557caab92a76 (patch)
tree72ff8cc4f0a65015dc7a114d321ee0df4a5da637
parent07210eef3741781e0483cec656b99477c9185732 (diff)
downloadpyjwt-dc8dc7d05d54bd5502295601c01b557caab92a76.tar.gz
Remove unnecessary compatibility shims for Python 2 (#498)
As the project is Python 3 only, can remove the compatibility shims in compat.py. Type checking has been simplified where it can: - str is iterable - bytes is iterable - use isinstance instead of issubclass The remaining function bytes_from_int() has been moved to utils.py.
-rw-r--r--docs/api.rst2
-rw-r--r--jwt/__main__.py9
-rw-r--r--jwt/algorithms.py7
-rw-r--r--jwt/api_jws.py10
-rw-r--r--jwt/api_jwt.py12
-rw-r--r--jwt/compat.py30
-rw-r--r--jwt/contrib/algorithms/py_ecdsa.py5
-rw-r--r--jwt/contrib/algorithms/py_ed25519.py5
-rw-r--r--jwt/contrib/algorithms/pycrypto.py5
-rw-r--r--jwt/utils.py25
-rw-r--r--tests/test_api_jwt.py2
-rw-r--r--tests/test_compat.py17
12 files changed, 46 insertions, 83 deletions
diff --git a/docs/api.rst b/docs/api.rst
index 3f66173..6ca0b62 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -52,7 +52,7 @@ API Reference
Use ``verify_exp`` instead
- :param str|iterable audience: optional, the value for ``verify_aud`` check
+ :param iterable audience: optional, the value for ``verify_aud`` check
:param str issuer: optional, the value for ``verify_iss`` check
:param int|float leeway: a time margin in seconds for the expiration check
:param bool verify_expiration:
diff --git a/jwt/__main__.py b/jwt/__main__.py
index 4fd3d33..9ed8b71 100644
--- a/jwt/__main__.py
+++ b/jwt/__main__.py
@@ -71,8 +71,13 @@ def decode_payload(args):
raise OSError("Cannot read from stdin: terminal not a TTY")
token = token.encode("utf-8")
- data = decode(token, key=args.key, verify=args.verify,
- audience=args.audience, issuer=args.issuer)
+ data = decode(
+ token,
+ key=args.key,
+ verify=args.verify,
+ audience=args.audience,
+ issuer=args.issuer,
+ )
return json.dumps(data)
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 10ef680..3a7040c 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -2,7 +2,6 @@ import hashlib
import hmac
import json
-from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import (
base64url_decode,
@@ -216,7 +215,7 @@ class HMACAlgorithm(Algorithm):
return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg, key, sig):
- return constant_time_compare(sig, self.sign(msg, key))
+ return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto: # noqa: C901
@@ -238,7 +237,7 @@ if has_crypto: # noqa: C901
if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey):
return key
- if isinstance(key, string_types):
+ if isinstance(key, (bytes, str)):
key = force_bytes(key)
try:
@@ -395,7 +394,7 @@ if has_crypto: # noqa: C901
):
return key
- if isinstance(key, string_types):
+ if isinstance(key, (bytes, str)):
key = force_bytes(key)
# Attempt to load key. We don't know if it's
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index ef5d7ec..287dbf4 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,10 +1,10 @@
import binascii
import json
import warnings
+from collections.abc import Mapping
from .algorithms import requires_cryptography # NOQA
from .algorithms import Algorithm, get_default_algorithms, has_crypto
-from .compat import Mapping, binary_type, string_types, text_type
from .exceptions import (
DecodeError,
InvalidAlgorithmError,
@@ -178,12 +178,12 @@ class PyJWS:
return headers
def _load(self, jwt):
- if isinstance(jwt, text_type):
+ if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
- if not issubclass(type(jwt), binary_type):
+ if not isinstance(jwt, bytes):
raise DecodeError(
- "Invalid token type. Token must be a {}".format(binary_type)
+ "Invalid token type. Token must be a {}".format(bytes)
)
try:
@@ -249,7 +249,7 @@ class PyJWS:
self._validate_kid(headers["kid"])
def _validate_kid(self, kid):
- if not isinstance(kid, string_types):
+ if not isinstance(kid, (bytes, str)):
raise InvalidTokenError("Key ID header parameter must be a string")
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 4289b40..b42a4d6 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -1,11 +1,11 @@
import json
import warnings
from calendar import timegm
+from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from .algorithms import Algorithm, get_default_algorithms # NOQA
from .api_jws import PyJWS
-from .compat import Iterable, Mapping, string_types
from .exceptions import (
DecodeError,
ExpiredSignatureError,
@@ -149,8 +149,8 @@ class PyJWT(PyJWS):
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
- if not isinstance(audience, (string_types, type(None), Iterable)):
- raise TypeError("audience must be a string, iterable, or None")
+ if not isinstance(audience, (type(None), Iterable)):
+ raise TypeError("audience must be an iterable or None")
self._validate_required_claims(payload, options)
@@ -220,14 +220,14 @@ class PyJWT(PyJWS):
audience_claims = payload["aud"]
- if isinstance(audience_claims, string_types):
+ if isinstance(audience_claims, (bytes, str)):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
raise InvalidAudienceError("Invalid claim format in token")
- if any(not isinstance(c, string_types) for c in audience_claims):
+ if any(not isinstance(c, (bytes, str)) for c in audience_claims):
raise InvalidAudienceError("Invalid claim format in token")
- if isinstance(audience, string_types):
+ if isinstance(audience, (bytes, str)):
audience = [audience]
if not any(aud in audience_claims for aud in audience):
diff --git a/jwt/compat.py b/jwt/compat.py
deleted file mode 100644
index b4fec15..0000000
--- a/jwt/compat.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""
-The `compat` module provides support for backwards compatibility with older
-versions of python, and compatibility wrappers around optional packages.
-"""
-# flake8: noqa
-import hmac
-
-text_type = str
-binary_type = bytes
-string_types = (str, bytes)
-
-try:
- # Importing ABCs from collections will be removed in PY3.8
- from collections.abc import Iterable, Mapping
-except ImportError:
- from collections import Iterable, Mapping
-
-
-constant_time_compare = hmac.compare_digest
-
-
-def bytes_from_int(val):
- remaining = val
- byte_length = 0
-
- while remaining != 0:
- remaining = remaining >> 8
- byte_length += 1
-
- return val.to_bytes(byte_length, "big", signed=False)
diff --git a/jwt/contrib/algorithms/py_ecdsa.py b/jwt/contrib/algorithms/py_ecdsa.py
index 5b878f5..3f5fa76 100644
--- a/jwt/contrib/algorithms/py_ecdsa.py
+++ b/jwt/contrib/algorithms/py_ecdsa.py
@@ -6,7 +6,6 @@ import hashlib
import ecdsa
from jwt.algorithms import Algorithm
-from jwt.compat import string_types, text_type
class ECAlgorithm(Algorithm):
@@ -33,8 +32,8 @@ class ECAlgorithm(Algorithm):
):
return key
- if isinstance(key, string_types):
- if isinstance(key, text_type):
+ if isinstance(key, (bytes, str)):
+ if isinstance(key, str):
key = key.encode("utf-8")
# Attempt to load key. We don't know if it's
diff --git a/jwt/contrib/algorithms/py_ed25519.py b/jwt/contrib/algorithms/py_ed25519.py
index 1a1d4da..a86000a 100644
--- a/jwt/contrib/algorithms/py_ed25519.py
+++ b/jwt/contrib/algorithms/py_ed25519.py
@@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.serialization import (
)
from jwt.algorithms import Algorithm
-from jwt.compat import string_types, text_type
class Ed25519Algorithm(Algorithm):
@@ -33,8 +32,8 @@ class Ed25519Algorithm(Algorithm):
if isinstance(key, (Ed25519PrivateKey, Ed25519PublicKey)):
return key
- if isinstance(key, string_types):
- if isinstance(key, text_type):
+ if isinstance(key, (bytes, str)):
+ if isinstance(key, str):
key = key.encode("utf-8")
str_key = key.decode("utf-8")
diff --git a/jwt/contrib/algorithms/pycrypto.py b/jwt/contrib/algorithms/pycrypto.py
index d58e907..0fb5871 100644
--- a/jwt/contrib/algorithms/pycrypto.py
+++ b/jwt/contrib/algorithms/pycrypto.py
@@ -5,7 +5,6 @@ from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from jwt.algorithms import Algorithm
-from jwt.compat import string_types, text_type
class RSAAlgorithm(Algorithm):
@@ -30,8 +29,8 @@ class RSAAlgorithm(Algorithm):
if isinstance(key, RSA._RSAobj):
return key
- if isinstance(key, string_types):
- if isinstance(key, text_type):
+ if isinstance(key, (bytes, str)):
+ if isinstance(key, str):
key = key.encode("utf-8")
key = RSA.importKey(key)
diff --git a/jwt/utils.py b/jwt/utils.py
index cc7f56c..0e8f210 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -2,8 +2,6 @@ import base64
import binascii
import struct
-from .compat import binary_type, bytes_from_int, text_type
-
try:
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature,
@@ -14,25 +12,25 @@ except ImportError:
def force_unicode(value):
- if isinstance(value, binary_type):
+ if isinstance(value, bytes):
return value.decode("utf-8")
- elif isinstance(value, text_type):
+ elif isinstance(value, str):
return value
else:
raise TypeError("Expected a string value")
def force_bytes(value):
- if isinstance(value, text_type):
+ if isinstance(value, str):
return value.encode("utf-8")
- elif isinstance(value, binary_type):
+ elif isinstance(value, bytes):
return value
else:
raise TypeError("Expected a string value")
def base64url_decode(input):
- if isinstance(input, text_type):
+ if isinstance(input, str):
input = input.encode("ascii")
rem = len(input) % 4
@@ -60,7 +58,7 @@ def to_base64url_uint(val):
def from_base64url_uint(val):
- if isinstance(val, text_type):
+ if isinstance(val, str):
val = val.encode("ascii")
data = base64url_decode(val)
@@ -92,6 +90,17 @@ def bytes_to_number(string):
return int(binascii.b2a_hex(string), 16)
+def bytes_from_int(val):
+ remaining = val
+ byte_length = 0
+
+ while remaining != 0:
+ remaining = remaining >> 8
+ byte_length += 1
+
+ return val.to_bytes(byte_length, "big", signed=False)
+
+
def der_to_raw_signature(der_sig, curve):
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py
index 829f519..50c748b 100644
--- a/tests/test_api_jwt.py
+++ b/tests/test_api_jwt.py
@@ -99,7 +99,7 @@ class TestJWT:
jwt.decode(example_jwt, secret, audience=1)
exception = context.value
- assert str(exception) == "audience must be a string, iterable, or None"
+ assert str(exception) == "audience must be an iterable or None"
def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt):
secret = "secret"
diff --git a/tests/test_compat.py b/tests/test_compat.py
deleted file mode 100644
index 7a7516c..0000000
--- a/tests/test_compat.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from jwt.compat import constant_time_compare
-from jwt.utils import force_bytes
-
-
-class TestCompat:
- def test_constant_time_compare_returns_true_if_same(self):
- assert constant_time_compare(force_bytes("abc"), force_bytes("abc"))
-
- def test_constant_time_compare_returns_false_if_diff_lengths(self):
- assert not constant_time_compare(
- force_bytes("abc"), force_bytes("abcd")
- )
-
- def test_constant_time_compare_returns_false_if_totally_different(self):
- assert not constant_time_compare(
- force_bytes("abcd"), force_bytes("efgh")
- )