summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Dufresne <jon.dufresne@gmail.com>2020-12-20 09:26:36 -0800
committerGitHub <noreply@github.com>2020-12-20 12:26:36 -0500
commit5d7bdd68aec70a8a8cc4979bda93adc3c44f945f (patch)
tree42c0b6abe60baefd8d6ee9c28a0313b3f9d1102f
parent5bf8e70f26309b0a5ceda49ce0a7e97e89eee764 (diff)
downloadpyjwt-5d7bdd68aec70a8a8cc4979bda93adc3c44f945f.tar.gz
Type hint jwt.utils module (#564)
-rw-r--r--jwt/utils.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/jwt/utils.py b/jwt/utils.py
index f9579d3..6a09dcf 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -1,17 +1,19 @@
import base64
import binascii
import struct
+from typing import Any, Union
try:
+ from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature,
encode_dss_signature,
)
except ImportError:
- pass
+ EllipticCurve = Any # type: ignore
-def force_bytes(value):
+def force_bytes(value: Union[str, bytes]) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
@@ -20,7 +22,7 @@ def force_bytes(value):
raise TypeError("Expected a string value")
-def base64url_decode(input):
+def base64url_decode(input: Union[str, bytes]) -> bytes:
if isinstance(input, str):
input = input.encode("ascii")
@@ -36,7 +38,7 @@ def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")
-def to_base64url_uint(val):
+def to_base64url_uint(val: int) -> bytes:
if val < 0:
raise ValueError("Must be a positive integer")
@@ -48,7 +50,7 @@ def to_base64url_uint(val):
return base64url_encode(int_bytes)
-def from_base64url_uint(val):
+def from_base64url_uint(val: Union[str, bytes]) -> int:
if isinstance(val, str):
val = val.encode("ascii")
@@ -58,17 +60,17 @@ def from_base64url_uint(val):
return int("".join(["%02x" % byte for byte in buf]), 16)
-def number_to_bytes(num, num_bytes):
+def number_to_bytes(num: int, num_bytes: int) -> bytes:
padded_hex = "%0*x" % (2 * num_bytes, num)
big_endian = binascii.a2b_hex(padded_hex.encode("ascii"))
return big_endian
-def bytes_to_number(string):
+def bytes_to_number(string: bytes) -> int:
return int(binascii.b2a_hex(string), 16)
-def bytes_from_int(val):
+def bytes_from_int(val: int) -> bytes:
remaining = val
byte_length = 0
@@ -79,7 +81,7 @@ def bytes_from_int(val):
return val.to_bytes(byte_length, "big", signed=False)
-def der_to_raw_signature(der_sig, curve):
+def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@@ -88,7 +90,7 @@ def der_to_raw_signature(der_sig, curve):
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
-def raw_to_der_signature(raw_sig, curve):
+def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8