summaryrefslogtreecommitdiff
path: root/jwt/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwt/utils.py')
-rw-r--r--jwt/utils.py30
1 files changed, 13 insertions, 17 deletions
diff --git a/jwt/utils.py b/jwt/utils.py
index 16cae06..92e3c8b 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -1,7 +1,7 @@
import base64
import binascii
import re
-from typing import Union
+from typing import Any, AnyStr
try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
@@ -10,10 +10,10 @@ try:
encode_dss_signature,
)
except ModuleNotFoundError:
- EllipticCurve = None
+ pass
-def force_bytes(value: Union[str, bytes]) -> bytes:
+def force_bytes(value: Any) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
@@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes:
raise TypeError("Expected a string value")
-def base64url_decode(input: Union[str, bytes]) -> bytes:
- if isinstance(input, str):
- input = input.encode("ascii")
+def base64url_decode(input: AnyStr) -> bytes:
+ input_bytes = force_bytes(input)
- rem = len(input) % 4
+ rem = len(input_bytes) % 4
if rem > 0:
- input += b"=" * (4 - rem)
+ input_bytes += b"=" * (4 - rem)
- return base64.urlsafe_b64decode(input)
+ return base64.urlsafe_b64decode(input_bytes)
def base64url_encode(input: bytes) -> bytes:
@@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes:
return base64url_encode(int_bytes)
-def from_base64url_uint(val: Union[str, bytes]) -> int:
- if isinstance(val, str):
- val = val.encode("ascii")
-
- data = base64url_decode(val)
+def from_base64url_uint(val: AnyStr) -> int:
+ data = base64url_decode(force_bytes(val))
return int.from_bytes(data, byteorder="big")
@@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes:
return val.to_bytes(byte_length, "big", signed=False)
-def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
+def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
-def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
+def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
r = bytes_to_number(raw_sig[:num_bytes])
s = bytes_to_number(raw_sig[num_bytes:])
- return encode_dss_signature(r, s)
+ return bytes(encode_dss_signature(r, s))
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252