summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHubert Kario <hkario@redhat.com>2021-07-15 21:32:15 +0200
committerHubert Kario <hkario@redhat.com>2021-07-20 15:19:53 +0200
commitb52ef334e2bb472d13500b232f520aed7a5ba8b2 (patch)
tree5cc4c3ea3798900684d5874141eccf8314d54fcf /src
parent34e9cec03ca75f71de3334131608512ec6f80fb6 (diff)
downloadecdsa-b52ef334e2bb472d13500b232f520aed7a5ba8b2.tar.gz
add SHAKE-256 implementation
On earlier pythons we don't have the ability to set the size of SHAKE-256 output, so we need to use our own implementation.
Diffstat (limited to 'src')
-rw-r--r--src/ecdsa/_compat.py79
-rw-r--r--src/ecdsa/_sha3.py182
-rw-r--r--src/ecdsa/test_sha3.py111
3 files changed, 367 insertions, 5 deletions
diff --git a/src/ecdsa/_compat.py b/src/ecdsa/_compat.py
index d773e75..5b31b12 100644
--- a/src/ecdsa/_compat.py
+++ b/src/ecdsa/_compat.py
@@ -15,6 +15,8 @@ def str_idx_as_int(string, index):
if sys.version_info < (3, 0): # pragma: no branch
+ import binascii
+ import platform
def normalise_bytes(buffer_object):
"""Cast the input into array of bytes."""
@@ -24,22 +26,64 @@ if sys.version_info < (3, 0): # pragma: no branch
def hmac_compat(ret):
return ret
- if sys.version_info < (2, 7) or sys.version_info < ( # pragma: no branch
- 2,
- 7,
- 4,
- ):
+ if (
+ sys.version_info < (2, 7)
+ or sys.version_info < (2, 7, 4)
+ or platform.system() == "Java"
+ ): # pragma: no branch
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text)
+ def compat26_str(val):
+ return str(val)
+
+ def bit_length(val):
+ if val == 0:
+ return 0
+ return len(bin(val)) - 2
+
else:
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
+ def compat26_str(val):
+ return val
+
+ def bit_length(val):
+ """Return number of bits necessary to represent an integer."""
+ return val.bit_length()
+
+ def b2a_hex(val):
+ return binascii.b2a_hex(compat26_str(val))
+
+ def bytes_to_int(val, byteorder):
+ """Convert bytes to an int."""
+ if not val:
+ return 0
+ if byteorder == "big":
+ return int(b2a_hex(val), 16)
+ if byteorder == "little":
+ return int(b2a_hex(val[::-1]), 16)
+ raise ValueError("Only 'big' and 'little' endian supported")
+
+ def int_to_bytes(val, length=None, byteorder="big"):
+ """Return number converted to bytes"""
+ if length is None:
+ length = byte_length(val)
+ if byteorder == "big":
+ return bytearray(
+ (val >> i) & 0xFF for i in reversed(range(0, length * 8, 8))
+ )
+ if byteorder == "little":
+ return bytearray(
+ (val >> i) & 0xFF for i in range(0, length * 8, 8)
+ )
+ raise ValueError("Only 'big' or 'little' endian supported")
+
else:
if sys.version_info < (3, 4): # pragma: no branch
@@ -62,3 +106,28 @@ else:
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
+
+ # pylint: disable=invalid-name
+ # pylint is stupid here and deson't notice it's a function, not
+ # constant
+ bytes_to_int = int.from_bytes
+ # pylint: enable=invalid-name
+
+ def bit_length(val):
+ """Return number of bits necessary to represent an integer."""
+ return val.bit_length()
+
+ def int_to_bytes(val, length=None, byteorder="big"):
+ """Convert integer to bytes."""
+ if length is None:
+ length = byte_length(val)
+ # for gmpy we need to convert back to native int
+ if type(val) != int:
+ val = int(val)
+ return bytearray(val.to_bytes(length=length, byteorder=byteorder))
+
+
+def byte_length(val):
+ """Return number of bytes necessary to represent an integer."""
+ length = bit_length(val)
+ return (length + 7) // 8
diff --git a/src/ecdsa/_sha3.py b/src/ecdsa/_sha3.py
new file mode 100644
index 0000000..78a1f1f
--- /dev/null
+++ b/src/ecdsa/_sha3.py
@@ -0,0 +1,182 @@
+"""
+Implementation of the SHAKE-256 algorithm for Ed448
+"""
+
+try:
+ import hashlib
+
+ hashlib.new("shake256").digest(64)
+
+ def shake_256(msg, outlen):
+ return hashlib.new("shake256", msg).digest(outlen)
+
+
+except (TypeError, ValueError):
+
+ from ._compat import bytes_to_int, int_to_bytes
+
+ # From little endian.
+ def _from_le(s):
+ return bytes_to_int(s, byteorder="little")
+
+ # Rotate a word x by b places to the left.
+ def _rol(x, b):
+ return ((x << b) | (x >> (64 - b))) & (2 ** 64 - 1)
+
+ # Do the SHA-3 state transform on state s.
+ def _sha3_transform(s):
+ ROTATIONS = [
+ 0,
+ 1,
+ 62,
+ 28,
+ 27,
+ 36,
+ 44,
+ 6,
+ 55,
+ 20,
+ 3,
+ 10,
+ 43,
+ 25,
+ 39,
+ 41,
+ 45,
+ 15,
+ 21,
+ 8,
+ 18,
+ 2,
+ 61,
+ 56,
+ 14,
+ ]
+ PERMUTATION = [
+ 1,
+ 6,
+ 9,
+ 22,
+ 14,
+ 20,
+ 2,
+ 12,
+ 13,
+ 19,
+ 23,
+ 15,
+ 4,
+ 24,
+ 21,
+ 8,
+ 16,
+ 5,
+ 3,
+ 18,
+ 17,
+ 11,
+ 7,
+ 10,
+ ]
+ RC = [
+ 0x0000000000000001,
+ 0x0000000000008082,
+ 0x800000000000808A,
+ 0x8000000080008000,
+ 0x000000000000808B,
+ 0x0000000080000001,
+ 0x8000000080008081,
+ 0x8000000000008009,
+ 0x000000000000008A,
+ 0x0000000000000088,
+ 0x0000000080008009,
+ 0x000000008000000A,
+ 0x000000008000808B,
+ 0x800000000000008B,
+ 0x8000000000008089,
+ 0x8000000000008003,
+ 0x8000000000008002,
+ 0x8000000000000080,
+ 0x000000000000800A,
+ 0x800000008000000A,
+ 0x8000000080008081,
+ 0x8000000000008080,
+ 0x0000000080000001,
+ 0x8000000080008008,
+ ]
+
+ for rnd in range(0, 24):
+ # AddColumnParity (Theta)
+ c = [0] * 5
+ d = [0] * 5
+ for i in range(0, 25):
+ c[i % 5] ^= s[i]
+ for i in range(0, 5):
+ d[i] = c[(i + 4) % 5] ^ _rol(c[(i + 1) % 5], 1)
+ for i in range(0, 25):
+ s[i] ^= d[i % 5]
+ # RotateWords (Rho)
+ for i in range(0, 25):
+ s[i] = _rol(s[i], ROTATIONS[i])
+ # PermuteWords (Pi)
+ t = s[PERMUTATION[0]]
+ for i in range(0, len(PERMUTATION) - 1):
+ s[PERMUTATION[i]] = s[PERMUTATION[i + 1]]
+ s[PERMUTATION[-1]] = t
+ # NonlinearMixRows (Chi)
+ for i in range(0, 25, 5):
+ t = [
+ s[i],
+ s[i + 1],
+ s[i + 2],
+ s[i + 3],
+ s[i + 4],
+ s[i],
+ s[i + 1],
+ ]
+ for j in range(0, 5):
+ s[i + j] = t[j] ^ ((~t[j + 1]) & (t[j + 2]))
+ # AddRoundConstant (Iota)
+ s[0] ^= RC[rnd]
+
+ # Reinterpret octet array b to word array and XOR it to state s.
+ def _reinterpret_to_words_and_xor(s, b):
+ for j in range(0, len(b) // 8):
+ s[j] ^= _from_le(b[8 * j : 8 * j + 8])
+
+ # Reinterpret word array w to octet array and return it.
+ def _reinterpret_to_octets(w):
+ mp = bytearray()
+ for j in range(0, len(w)):
+ mp += int_to_bytes(w[j], 8, byteorder="little")
+ return mp
+
+ def _sha3_raw(msg, r_w, o_p, e_b):
+ """Semi-generic SHA-3 implementation"""
+ r_b = 8 * r_w
+ s = [0] * 25
+ # Handle whole blocks.
+ idx = 0
+ blocks = len(msg) // r_b
+ for i in range(0, blocks):
+ _reinterpret_to_words_and_xor(s, msg[idx : idx + r_b])
+ idx += r_b
+ _sha3_transform(s)
+ # Handle last block padding.
+ m = bytearray(msg[idx:])
+ m.append(o_p)
+ while len(m) < r_b:
+ m.append(0)
+ m[len(m) - 1] |= 128
+ # Handle padded last block.
+ _reinterpret_to_words_and_xor(s, m)
+ _sha3_transform(s)
+ # Output.
+ out = bytearray()
+ while len(out) < e_b:
+ out += _reinterpret_to_octets(s[:r_w])
+ _sha3_transform(s)
+ return out[:e_b]
+
+ def shake_256(msg, outlen):
+ return _sha3_raw(msg, 17, 31, outlen)
diff --git a/src/ecdsa/test_sha3.py b/src/ecdsa/test_sha3.py
new file mode 100644
index 0000000..2c6bd15
--- /dev/null
+++ b/src/ecdsa/test_sha3.py
@@ -0,0 +1,111 @@
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import pytest
+
+try:
+ from gmpy2 import mpz
+
+ GMPY = True
+except ImportError:
+ try:
+ from gmpy import mpz
+
+ GMPY = True
+ except ImportError:
+ GMPY = False
+
+from ._sha3 import shake_256
+from ._compat import bytes_to_int, int_to_bytes
+
+B2I_VECTORS = [
+ (b"\x00\x01", "big", 1),
+ (b"\x00\x01", "little", 0x0100),
+ (b"", "big", 0),
+ (b"\x00", "little", 0),
+]
+
+
+@pytest.mark.parametrize("bytes_in,endian,int_out", B2I_VECTORS)
+def test_bytes_to_int(bytes_in, endian, int_out):
+ out = bytes_to_int(bytes_in, endian)
+ assert out == int_out
+
+
+class TestBytesToInt(unittest.TestCase):
+ def test_bytes_to_int_wrong_endian(self):
+ with self.assertRaises(ValueError):
+ bytes_to_int(b"\x00", "middle")
+
+ def test_int_to_bytes_wrong_endian(self):
+ with self.assertRaises(ValueError):
+ int_to_bytes(0, byteorder="middle")
+
+
+@pytest.mark.skipif(GMPY == False, reason="requites gmpy or gmpy2")
+def test_int_to_bytes_with_gmpy():
+ assert int_to_bytes(mpz(1)) == b"\x01"
+
+
+I2B_VECTORS = [
+ (0, None, "big", b""),
+ (0, 1, "big", b"\x00"),
+ (1, None, "big", b"\x01"),
+ (0x0100, None, "little", b"\x00\x01"),
+ (0x0100, 4, "little", b"\x00\x01\x00\x00"),
+ (1, 4, "big", b"\x00\x00\x00\x01"),
+]
+
+
+@pytest.mark.parametrize("int_in,length,endian,bytes_out", I2B_VECTORS)
+def test_int_to_bytes(int_in, length, endian, bytes_out):
+ out = int_to_bytes(int_in, length, endian)
+ assert out == bytes_out
+
+
+SHAKE_256_VECTORS = [
+ (
+ b"Message.",
+ 32,
+ b"\x78\xa1\x37\xbb\x33\xae\xe2\x72\xb1\x02\x4f\x39\x43\xe5\xcf\x0c"
+ b"\x4e\x9c\x72\x76\x2e\x34\x4c\xf8\xf9\xc3\x25\x9d\x4f\x91\x2c\x3a",
+ ),
+ (
+ b"",
+ 32,
+ b"\x46\xb9\xdd\x2b\x0b\xa8\x8d\x13\x23\x3b\x3f\xeb\x74\x3e\xeb\x24"
+ b"\x3f\xcd\x52\xea\x62\xb8\x1b\x82\xb5\x0c\x27\x64\x6e\xd5\x76\x2f",
+ ),
+ (
+ b"message",
+ 32,
+ b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51"
+ b"\xa2\xa6\x5a\xf8\x64\xfc\xb1\x26\xc2\x66\x0a\xb3\x46\x51\xb1\x75",
+ ),
+ (
+ b"message",
+ 16,
+ b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51",
+ ),
+ (
+ b"message",
+ 64,
+ b"\x86\x16\xe1\xe4\xcf\xd8\xb5\xf7\xd9\x2d\x43\xd8\x6e\x1b\x14\x51"
+ b"\xa2\xa6\x5a\xf8\x64\xfc\xb1\x26\xc2\x66\x0a\xb3\x46\x51\xb1\x75"
+ b"\x30\xd6\xba\x2a\x46\x65\xf1\x9d\xf0\x62\x25\xb1\x26\xd1\x3e\xed"
+ b"\x91\xd5\x0d\xe7\xb9\xcb\x65\xf3\x3a\x46\xae\xd3\x6c\x7d\xc5\xe8",
+ ),
+ (
+ b"A" * 1024,
+ 32,
+ b"\xa5\xef\x7e\x30\x8b\xe8\x33\x64\xe5\x9c\xf3\xb5\xf3\xba\x20\xa3"
+ b"\x5a\xe7\x30\xfd\xbc\x33\x11\xbf\x83\x89\x50\x82\xb4\x41\xe9\xb3",
+ ),
+]
+
+
+@pytest.mark.parametrize("msg,olen,ohash", SHAKE_256_VECTORS)
+def test_shake_256(msg, olen, ohash):
+ out = shake_256(msg, olen)
+ assert out == bytearray(ohash)