diff options
-rw-r--r-- | rsa/_compat.py | 21 | ||||
-rw-r--r-- | tests/test_compat.py | 47 |
2 files changed, 67 insertions, 1 deletions
diff --git a/rsa/_compat.py b/rsa/_compat.py index 1e51368..38bab08 100644 --- a/rsa/_compat.py +++ b/rsa/_compat.py @@ -99,6 +99,27 @@ def byte(num): return pack("B", num) +def xor_bytes(b1, b2): + """ + Returns the bitwise XOR result between two bytes objects, b1 ^ b2. + + Bitwise XOR operation is commutative, so order of parameters doesn't + generate different results. If parameters have different length, extra + length of the largest one is ignored. + + :param b1: + First bytes object. + :param b2: + Second bytes object. + :returns: + Bytes object, result of XOR operation. + """ + if PY2: + return ''.join(byte(ord(x) ^ ord(y)) for x, y in zip(b1, b2)) + + return bytes(x ^ y for x, y in zip(b1, b2)) + + def get_word_alignment(num, force_arch=64, _machine_word_size=MACHINE_WORD_SIZE): """ diff --git a/tests/test_compat.py b/tests/test_compat.py index a47f890..62e933f 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -17,10 +17,12 @@ import unittest import struct -from rsa._compat import byte, is_bytes, range +from rsa._compat import byte, is_bytes, range, xor_bytes class TestByte(unittest.TestCase): + """Tests for single bytes.""" + def test_byte(self): for i in range(256): byt = byte(i) @@ -33,3 +35,46 @@ class TestByte(unittest.TestCase): def test_byte_literal(self): self.assertIsInstance(b'abc', bytes) + + +class TestBytes(unittest.TestCase): + """Tests for bytes objects.""" + + def setUp(self): + self.b1 = b'\xff\xff\xff\xff' + self.b2 = b'\x00\x00\x00\x00' + self.b3 = b'\xf0\xf0\xf0\xf0' + self.b4 = b'\x4d\x23\xca\xe2' + self.b5 = b'\x9b\x61\x3b\xdc' + self.b6 = b'\xff\xff' + + self.byte_strings = (self.b1, self.b2, self.b3, self.b4, self.b5, self.b6) + + def test_xor_bytes(self): + self.assertEqual(xor_bytes(self.b1, self.b2), b'\xff\xff\xff\xff') + self.assertEqual(xor_bytes(self.b1, self.b3), b'\x0f\x0f\x0f\x0f') + self.assertEqual(xor_bytes(self.b1, self.b4), b'\xb2\xdc\x35\x1d') + self.assertEqual(xor_bytes(self.b1, self.b5), b'\x64\x9e\xc4\x23') + self.assertEqual(xor_bytes(self.b2, self.b3), b'\xf0\xf0\xf0\xf0') + self.assertEqual(xor_bytes(self.b2, self.b4), b'\x4d\x23\xca\xe2') + self.assertEqual(xor_bytes(self.b2, self.b5), b'\x9b\x61\x3b\xdc') + self.assertEqual(xor_bytes(self.b3, self.b4), b'\xbd\xd3\x3a\x12') + self.assertEqual(xor_bytes(self.b3, self.b5), b'\x6b\x91\xcb\x2c') + self.assertEqual(xor_bytes(self.b4, self.b5), b'\xd6\x42\xf1\x3e') + + def test_xor_bytes_length(self): + self.assertEqual(xor_bytes(self.b1, self.b6), b'\x00\x00') + self.assertEqual(xor_bytes(self.b2, self.b6), b'\xff\xff') + self.assertEqual(xor_bytes(self.b3, self.b6), b'\x0f\x0f') + self.assertEqual(xor_bytes(self.b4, self.b6), b'\xb2\xdc') + self.assertEqual(xor_bytes(self.b5, self.b6), b'\x64\x9e') + self.assertEqual(xor_bytes(self.b6, b''), b'') + + def test_xor_bytes_commutative(self): + for first in self.byte_strings: + for second in self.byte_strings: + min_length = min(len(first), len(second)) + result = xor_bytes(first, second) + + self.assertEqual(result, xor_bytes(second, first)) + self.assertEqual(len(result), min_length) |