diff options
Diffstat (limited to 'jwt/utils.py')
-rw-r--r-- | jwt/utils.py | 30 |
1 files changed, 13 insertions, 17 deletions
diff --git a/jwt/utils.py b/jwt/utils.py index 16cae06..92e3c8b 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Union +from typing import Any, AnyStr try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -10,10 +10,10 @@ try: encode_dss_signature, ) except ModuleNotFoundError: - EllipticCurve = None + pass -def force_bytes(value: Union[str, bytes]) -> bytes: +def force_bytes(value: Any) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): @@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes: raise TypeError("Expected a string value") -def base64url_decode(input: Union[str, bytes]) -> bytes: - if isinstance(input, str): - input = input.encode("ascii") +def base64url_decode(input: AnyStr) -> bytes: + input_bytes = force_bytes(input) - rem = len(input) % 4 + rem = len(input_bytes) % 4 if rem > 0: - input += b"=" * (4 - rem) + input_bytes += b"=" * (4 - rem) - return base64.urlsafe_b64decode(input) + return base64.urlsafe_b64decode(input_bytes) def base64url_encode(input: bytes) -> bytes: @@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes: return base64url_encode(int_bytes) -def from_base64url_uint(val: Union[str, bytes]) -> int: - if isinstance(val, str): - val = val.encode("ascii") - - data = base64url_decode(val) +def from_base64url_uint(val: AnyStr) -> int: + data = base64url_decode(force_bytes(val)) return int.from_bytes(data, byteorder="big") @@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes: return val.to_bytes(byte_length, "big", signed=False) -def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: +def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes) -def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: +def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: r = bytes_to_number(raw_sig[:num_bytes]) s = bytes_to_number(raw_sig[num_bytes:]) - return encode_dss_signature(r, s) + return bytes(encode_dss_signature(r, s)) # Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 |