diff options
author | Jon Dufresne <jon.dufresne@gmail.com> | 2020-06-18 22:53:17 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-19 11:53:17 +0600 |
commit | dc8dc7d05d54bd5502295601c01b557caab92a76 (patch) | |
tree | 72ff8cc4f0a65015dc7a114d321ee0df4a5da637 | |
parent | 07210eef3741781e0483cec656b99477c9185732 (diff) | |
download | pyjwt-dc8dc7d05d54bd5502295601c01b557caab92a76.tar.gz |
Remove unnecessary compatibility shims for Python 2 (#498)
As the project is Python 3 only, can remove the compatibility shims in
compat.py.
Type checking has been simplified where it can:
- str is iterable
- bytes is iterable
- use isinstance instead of issubclass
The remaining function bytes_from_int() has been moved to utils.py.
-rw-r--r-- | docs/api.rst | 2 | ||||
-rw-r--r-- | jwt/__main__.py | 9 | ||||
-rw-r--r-- | jwt/algorithms.py | 7 | ||||
-rw-r--r-- | jwt/api_jws.py | 10 | ||||
-rw-r--r-- | jwt/api_jwt.py | 12 | ||||
-rw-r--r-- | jwt/compat.py | 30 | ||||
-rw-r--r-- | jwt/contrib/algorithms/py_ecdsa.py | 5 | ||||
-rw-r--r-- | jwt/contrib/algorithms/py_ed25519.py | 5 | ||||
-rw-r--r-- | jwt/contrib/algorithms/pycrypto.py | 5 | ||||
-rw-r--r-- | jwt/utils.py | 25 | ||||
-rw-r--r-- | tests/test_api_jwt.py | 2 | ||||
-rw-r--r-- | tests/test_compat.py | 17 |
12 files changed, 46 insertions, 83 deletions
diff --git a/docs/api.rst b/docs/api.rst index 3f66173..6ca0b62 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -52,7 +52,7 @@ API Reference Use ``verify_exp`` instead - :param str|iterable audience: optional, the value for ``verify_aud`` check + :param iterable audience: optional, the value for ``verify_aud`` check :param str issuer: optional, the value for ``verify_iss`` check :param int|float leeway: a time margin in seconds for the expiration check :param bool verify_expiration: diff --git a/jwt/__main__.py b/jwt/__main__.py index 4fd3d33..9ed8b71 100644 --- a/jwt/__main__.py +++ b/jwt/__main__.py @@ -71,8 +71,13 @@ def decode_payload(args): raise OSError("Cannot read from stdin: terminal not a TTY") token = token.encode("utf-8") - data = decode(token, key=args.key, verify=args.verify, - audience=args.audience, issuer=args.issuer) + data = decode( + token, + key=args.key, + verify=args.verify, + audience=args.audience, + issuer=args.issuer, + ) return json.dumps(data) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 10ef680..3a7040c 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -2,7 +2,6 @@ import hashlib import hmac import json -from .compat import constant_time_compare, string_types from .exceptions import InvalidKeyError from .utils import ( base64url_decode, @@ -216,7 +215,7 @@ class HMACAlgorithm(Algorithm): return hmac.new(key, msg, self.hash_alg).digest() def verify(self, msg, key, sig): - return constant_time_compare(sig, self.sign(msg, key)) + return hmac.compare_digest(sig, self.sign(msg, key)) if has_crypto: # noqa: C901 @@ -238,7 +237,7 @@ if has_crypto: # noqa: C901 if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey): return key - if isinstance(key, string_types): + if isinstance(key, (bytes, str)): key = force_bytes(key) try: @@ -395,7 +394,7 @@ if has_crypto: # noqa: C901 ): return key - if isinstance(key, string_types): + if isinstance(key, (bytes, str)): key = force_bytes(key) # Attempt to load key. We don't know if it's diff --git a/jwt/api_jws.py b/jwt/api_jws.py index ef5d7ec..287dbf4 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,10 +1,10 @@ import binascii import json import warnings +from collections.abc import Mapping from .algorithms import requires_cryptography # NOQA from .algorithms import Algorithm, get_default_algorithms, has_crypto -from .compat import Mapping, binary_type, string_types, text_type from .exceptions import ( DecodeError, InvalidAlgorithmError, @@ -178,12 +178,12 @@ class PyJWS: return headers def _load(self, jwt): - if isinstance(jwt, text_type): + if isinstance(jwt, str): jwt = jwt.encode("utf-8") - if not issubclass(type(jwt), binary_type): + if not isinstance(jwt, bytes): raise DecodeError( - "Invalid token type. Token must be a {}".format(binary_type) + "Invalid token type. Token must be a {}".format(bytes) ) try: @@ -249,7 +249,7 @@ class PyJWS: self._validate_kid(headers["kid"]) def _validate_kid(self, kid): - if not isinstance(kid, string_types): + if not isinstance(kid, (bytes, str)): raise InvalidTokenError("Key ID header parameter must be a string") diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 4289b40..b42a4d6 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -1,11 +1,11 @@ import json import warnings from calendar import timegm +from collections.abc import Iterable, Mapping from datetime import datetime, timedelta from .algorithms import Algorithm, get_default_algorithms # NOQA from .api_jws import PyJWS -from .compat import Iterable, Mapping, string_types from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -149,8 +149,8 @@ class PyJWT(PyJWS): if isinstance(leeway, timedelta): leeway = leeway.total_seconds() - if not isinstance(audience, (string_types, type(None), Iterable)): - raise TypeError("audience must be a string, iterable, or None") + if not isinstance(audience, (type(None), Iterable)): + raise TypeError("audience must be an iterable or None") self._validate_required_claims(payload, options) @@ -220,14 +220,14 @@ class PyJWT(PyJWS): audience_claims = payload["aud"] - if isinstance(audience_claims, string_types): + if isinstance(audience_claims, (bytes, str)): audience_claims = [audience_claims] if not isinstance(audience_claims, list): raise InvalidAudienceError("Invalid claim format in token") - if any(not isinstance(c, string_types) for c in audience_claims): + if any(not isinstance(c, (bytes, str)) for c in audience_claims): raise InvalidAudienceError("Invalid claim format in token") - if isinstance(audience, string_types): + if isinstance(audience, (bytes, str)): audience = [audience] if not any(aud in audience_claims for aud in audience): diff --git a/jwt/compat.py b/jwt/compat.py deleted file mode 100644 index b4fec15..0000000 --- a/jwt/compat.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -The `compat` module provides support for backwards compatibility with older -versions of python, and compatibility wrappers around optional packages. -""" -# flake8: noqa -import hmac - -text_type = str -binary_type = bytes -string_types = (str, bytes) - -try: - # Importing ABCs from collections will be removed in PY3.8 - from collections.abc import Iterable, Mapping -except ImportError: - from collections import Iterable, Mapping - - -constant_time_compare = hmac.compare_digest - - -def bytes_from_int(val): - remaining = val - byte_length = 0 - - while remaining != 0: - remaining = remaining >> 8 - byte_length += 1 - - return val.to_bytes(byte_length, "big", signed=False) diff --git a/jwt/contrib/algorithms/py_ecdsa.py b/jwt/contrib/algorithms/py_ecdsa.py index 5b878f5..3f5fa76 100644 --- a/jwt/contrib/algorithms/py_ecdsa.py +++ b/jwt/contrib/algorithms/py_ecdsa.py @@ -6,7 +6,6 @@ import hashlib import ecdsa from jwt.algorithms import Algorithm -from jwt.compat import string_types, text_type class ECAlgorithm(Algorithm): @@ -33,8 +32,8 @@ class ECAlgorithm(Algorithm): ): return key - if isinstance(key, string_types): - if isinstance(key, text_type): + if isinstance(key, (bytes, str)): + if isinstance(key, str): key = key.encode("utf-8") # Attempt to load key. We don't know if it's diff --git a/jwt/contrib/algorithms/py_ed25519.py b/jwt/contrib/algorithms/py_ed25519.py index 1a1d4da..a86000a 100644 --- a/jwt/contrib/algorithms/py_ed25519.py +++ b/jwt/contrib/algorithms/py_ed25519.py @@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.serialization import ( ) from jwt.algorithms import Algorithm -from jwt.compat import string_types, text_type class Ed25519Algorithm(Algorithm): @@ -33,8 +32,8 @@ class Ed25519Algorithm(Algorithm): if isinstance(key, (Ed25519PrivateKey, Ed25519PublicKey)): return key - if isinstance(key, string_types): - if isinstance(key, text_type): + if isinstance(key, (bytes, str)): + if isinstance(key, str): key = key.encode("utf-8") str_key = key.decode("utf-8") diff --git a/jwt/contrib/algorithms/pycrypto.py b/jwt/contrib/algorithms/pycrypto.py index d58e907..0fb5871 100644 --- a/jwt/contrib/algorithms/pycrypto.py +++ b/jwt/contrib/algorithms/pycrypto.py @@ -5,7 +5,6 @@ from Crypto.PublicKey import RSA from Crypto.Signature import PKCS1_v1_5 from jwt.algorithms import Algorithm -from jwt.compat import string_types, text_type class RSAAlgorithm(Algorithm): @@ -30,8 +29,8 @@ class RSAAlgorithm(Algorithm): if isinstance(key, RSA._RSAobj): return key - if isinstance(key, string_types): - if isinstance(key, text_type): + if isinstance(key, (bytes, str)): + if isinstance(key, str): key = key.encode("utf-8") key = RSA.importKey(key) diff --git a/jwt/utils.py b/jwt/utils.py index cc7f56c..0e8f210 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -2,8 +2,6 @@ import base64 import binascii import struct -from .compat import binary_type, bytes_from_int, text_type - try: from cryptography.hazmat.primitives.asymmetric.utils import ( decode_dss_signature, @@ -14,25 +12,25 @@ except ImportError: def force_unicode(value): - if isinstance(value, binary_type): + if isinstance(value, bytes): return value.decode("utf-8") - elif isinstance(value, text_type): + elif isinstance(value, str): return value else: raise TypeError("Expected a string value") def force_bytes(value): - if isinstance(value, text_type): + if isinstance(value, str): return value.encode("utf-8") - elif isinstance(value, binary_type): + elif isinstance(value, bytes): return value else: raise TypeError("Expected a string value") def base64url_decode(input): - if isinstance(input, text_type): + if isinstance(input, str): input = input.encode("ascii") rem = len(input) % 4 @@ -60,7 +58,7 @@ def to_base64url_uint(val): def from_base64url_uint(val): - if isinstance(val, text_type): + if isinstance(val, str): val = val.encode("ascii") data = base64url_decode(val) @@ -92,6 +90,17 @@ def bytes_to_number(string): return int(binascii.b2a_hex(string), 16) +def bytes_from_int(val): + remaining = val + byte_length = 0 + + while remaining != 0: + remaining = remaining >> 8 + byte_length += 1 + + return val.to_bytes(byte_length, "big", signed=False) + + def der_to_raw_signature(der_sig, curve): num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 829f519..50c748b 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -99,7 +99,7 @@ class TestJWT: jwt.decode(example_jwt, secret, audience=1) exception = context.value - assert str(exception) == "audience must be a string, iterable, or None" + assert str(exception) == "audience must be an iterable or None" def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt): secret = "secret" diff --git a/tests/test_compat.py b/tests/test_compat.py deleted file mode 100644 index 7a7516c..0000000 --- a/tests/test_compat.py +++ /dev/null @@ -1,17 +0,0 @@ -from jwt.compat import constant_time_compare -from jwt.utils import force_bytes - - -class TestCompat: - def test_constant_time_compare_returns_true_if_same(self): - assert constant_time_compare(force_bytes("abc"), force_bytes("abc")) - - def test_constant_time_compare_returns_false_if_diff_lengths(self): - assert not constant_time_compare( - force_bytes("abc"), force_bytes("abcd") - ) - - def test_constant_time_compare_returns_false_if_totally_different(self): - assert not constant_time_compare( - force_bytes("abcd"), force_bytes("efgh") - ) |