diff options
Diffstat (limited to 'lib/Crypto/SelfTest/Cipher/common.py')
-rw-r--r-- | lib/Crypto/SelfTest/Cipher/common.py | 340 |
1 files changed, 325 insertions, 15 deletions
diff --git a/lib/Crypto/SelfTest/Cipher/common.py b/lib/Crypto/SelfTest/Cipher/common.py index a20a3aa..e52a781 100644 --- a/lib/Crypto/SelfTest/Cipher/common.py +++ b/lib/Crypto/SelfTest/Cipher/common.py @@ -30,8 +30,9 @@ __revision__ = "$Id$" import sys import unittest -from binascii import a2b_hex, b2a_hex +from binascii import a2b_hex, b2a_hex, hexlify from Crypto.Util.py3compat import * +from Crypto.Util.strxor import strxor_c # For compatibility with Python 2.1 and Python 2.2 if sys.hexversion < 0x02030000: @@ -68,14 +69,24 @@ class CipherSelfTest(unittest.TestCase): self.plaintext = b(_extract(params, 'plaintext')) self.ciphertext = b(_extract(params, 'ciphertext')) self.module_name = _extract(params, 'module_name', None) + self.assoc_data = _extract(params, 'assoc_data', None) + if self.assoc_data: + self.assoc_data = b(self.assoc_data) + self.mac = _extract(params, 'mac', None) + if self.assoc_data: + self.mac = b(self.mac) mode = _extract(params, 'mode', None) self.mode_name = str(mode) if mode is not None: # Block cipher self.mode = getattr(self.module, "MODE_" + mode) + self.iv = _extract(params, 'iv', None) - if self.iv is not None: self.iv = b(self.iv) + if self.iv is None: + self.iv = _extract(params, 'nonce', None) + if self.iv is not None: + self.iv = b(self.iv) # Only relevant for OPENPGP mode self.encrypted_iv = _extract(params, 'encrypted_iv', None) @@ -122,26 +133,49 @@ class CipherSelfTest(unittest.TestCase): def runTest(self): plaintext = a2b_hex(self.plaintext) ciphertext = a2b_hex(self.ciphertext) - - ct1 = b2a_hex(self._new().encrypt(plaintext)) - pt1 = b2a_hex(self._new(1).decrypt(ciphertext)) - ct2 = b2a_hex(self._new().encrypt(plaintext)) - pt2 = b2a_hex(self._new(1).decrypt(ciphertext)) + assoc_data = None + if self.assoc_data: + assoc_data = a2b_hex(self.assoc_data) + + ct = None + pt = None + + # + # Repeat the same encryption or decryption twice and verify + # that the result is always the same + # + for i in xrange(2): + cipher = self._new() + decipher = self._new(1) + + # Only AEAD modes + if self.assoc_data: + cipher.update(assoc_data) + decipher.update(assoc_data) + + ctX = b2a_hex(cipher.encrypt(plaintext)) + ptX = b2a_hex(decipher.decrypt(ciphertext)) + + if ct: + self.assertEqual(ct, ctX) + self.assertEqual(pt, ptX) + ct, pt = ctX, ptX if hasattr(self.module, "MODE_OPENPGP") and self.mode == self.module.MODE_OPENPGP: # In PGP mode, data returned by the first encrypt() # is prefixed with the encrypted IV. # Here we check it and then remove it from the ciphertexts. eilen = len(self.encrypted_iv) - self.assertEqual(self.encrypted_iv, ct1[:eilen]) - self.assertEqual(self.encrypted_iv, ct2[:eilen]) - ct1 = ct1[eilen:] - ct2 = ct2[eilen:] + self.assertEqual(self.encrypted_iv, ct[:eilen]) + ct = ct[eilen:] - self.assertEqual(self.ciphertext, ct1) # encrypt - self.assertEqual(self.ciphertext, ct2) # encrypt (second time) - self.assertEqual(self.plaintext, pt1) # decrypt - self.assertEqual(self.plaintext, pt2) # decrypt (second time) + self.assertEqual(self.ciphertext, ct) # encrypt + self.assertEqual(self.plaintext, pt) # decrypt + + if self.mac: + mac = b2a_hex(cipher.digest()) + self.assertEqual(self.mac, mac) + decipher.verify(a2b_hex(self.mac)) class CipherStreamingSelfTest(CipherSelfTest): @@ -252,6 +286,258 @@ class CFBSegmentSizeTest(unittest.TestCase): self.assertRaises(ValueError, self.module.new, a2b_hex(self.key), self.module.MODE_CFB, segment_size=i) self.module.new(a2b_hex(self.key), self.module.MODE_CFB, "\0"*self.module.block_size, segment_size=8) # should succeed +class CCMMACLengthTest(unittest.TestCase): + """CCM specific tests about MAC""" + + def __init__(self, module): + unittest.TestCase.__init__(self) + self.module = module + self.key = b('\xFF')*16 + self.iv = b('\x00')*10 + + def shortDescription(self): + return self.description + + def runTest(self): + """Verify that MAC can only be 4,6,8,..,16 bytes long.""" + for i in range(3,16,2): + self.description = "CCM MAC length check (%d bytes)" % i + self.assertRaises(ValueError, self.module.new, self.key, + self.module.MODE_CCM, self.iv, msg_len=10, mac_len=i) + + """Verify that default MAC length is 16.""" + self.description = "CCM default MAC length check" + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv, msg_len=4) + cipher.encrypt(b('z')*4) + self.assertEqual(len(cipher.digest()), 16) + +class CCMSplitEncryptionTest(unittest.TestCase): + """CCM specific tests to validate how encrypt() + decrypt() can be called multiple times on the + same object.""" + + def __init__(self, module): + unittest.TestCase.__init__(self) + self.module = module + self.key = b('\xFF')*16 + self.iv = b('\x00')*10 + self.description = "CCM Split Encryption Test" + + def shortDescription(self): + return self.description + + def runTest(self): + """Verify that CCM update()/encrypt() can be called multiple times, + provided that lengths are declared beforehand""" + + data = b("AUTH DATA") + pt1 = b("PLAINTEXT1") # Short + pt2 = b("PLAINTEXT2") # Long + pt_ref = pt1+pt2 + + # REFERENCE: Run with 1 update() and 1 encrypt() + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv) + cipher.update(data) + ct_ref = cipher.encrypt(pt_ref) + mac_ref = cipher.digest() + + # Verify that calling CCM encrypt()/decrypt() twice is not + # possible without the 'msg_len' parameter and regardless + # of the 'assoc_len' parameter + for ad_len in None, len(data): + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv, assoc_len=ad_len) + cipher.update(data) + cipher.encrypt(pt1) + self.assertRaises(TypeError, cipher.encrypt, pt2) + + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv, assoc_len=ad_len) + cipher.update(data) + cipher.decrypt(ct_ref[:len(pt1)]) + self.assertRaises(TypeError, cipher.decrypt, ct_ref[len(pt1):]) + + # Run with 2 encrypt()/decrypt(). Results must be the same + # regardless of the 'assoc_len' parameter + for ad_len in None, len(data): + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv, assoc_len=ad_len, msg_len=len(pt_ref)) + cipher.update(data) + ct = cipher.encrypt(pt1) + ct += cipher.encrypt(pt2) + mac = cipher.digest() + self.assertEqual(ct_ref, ct) + self.assertEqual(mac_ref, mac) + + cipher = self.module.new(self.key, self.module.MODE_CCM, + self.iv, msg_len=len(pt1+pt2)) + cipher.update(data) + pt = cipher.decrypt(ct[:len(pt1)]) + pt += cipher.decrypt(ct[len(pt1):]) + mac = cipher.verify(mac_ref) + self.assertEqual(pt_ref, pt) + +class AEADTests(unittest.TestCase): + """Tests generic to all AEAD modes""" + + def __init__(self, module, mode_name): + unittest.TestCase.__init__(self) + self.module = module + self.mode_name = mode_name + self.mode = getattr(module, mode_name) + self.key = b('\xFF')*16 + self.iv = b('\x00')*10 + self.description = "AEAD Test" + + def right_mac_test(self): + """Positive tests for MAC""" + + self.description = "Test for right MAC in %s of %s" % \ + (self.mode_name, self.module.__name__) + + ad_ref = b("Reference AD") + pt_ref = b("Reference plaintext") + + # Encrypt and create the reference MAC + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.update(ad_ref) + ct_ref = cipher.encrypt(pt_ref) + mac_ref = cipher.digest() + + # Decrypt and verify that MAC is accepted + decipher = self.module.new(self.key, self.mode, self.iv) + decipher.update(ad_ref) + pt = decipher.decrypt(ct_ref) + decipher.verify(mac_ref) + self.assertEqual(pt, pt_ref) + + # Verify that hexverify work + decipher.hexverify(hexlify(mac_ref)) + + def wrong_mac_test(self): + """Negative tests for MAC""" + + self.description = "Test for wrong MAC in %s of %s" % \ + (self.mode_name, self.module.__name__) + + ad_ref = b("Reference AD") + pt_ref = b("Reference plaintext") + + # Encrypt and create the reference MAC + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.update(ad_ref) + ct_ref = cipher.encrypt(pt_ref) + mac_ref = cipher.digest() + + # Modify the MAC and verify it is NOT ACCEPTED + wrong_mac = strxor_c(mac_ref, 255) + decipher = self.module.new(self.key, self.mode, self.iv) + decipher.update(ad_ref) + pt = decipher.decrypt(ct_ref) + self.assertRaises(ValueError, decipher.verify, wrong_mac) + + def zero_data(self): + """Verify transition from INITIALIZED to FINISHED""" + + self.description = "Test for zero data in %s of %s" % \ + (self.mode_name, self.module.__name__) + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.digest() + + def multiple_updates(self): + """Verify that update() can be called multiple times""" + + self.description = "Test for multiple updates in %s of %s" % \ + (self.mode_name, self.module.__name__) + + ad = b("").join([bchr(x) for x in xrange(0,128)]) + + mac1, mac2, mac3 = (None,)*3 + for chunk_length in 1,10,40,80,128: + chunks = [ad[i:i+chunk_length] for i in range(0, len(ad), chunk_length)] + + # No encryption/decryption + cipher = self.module.new(self.key, self.mode, self.iv) + for c in chunks: + cipher.update(c) + if mac1: + cipher.verify(mac1) + else: + mac1 = cipher.digest() + + # Encryption + cipher = self.module.new(self.key, self.mode, self.iv) + for c in chunks: + cipher.update(c) + ct = cipher.encrypt(b("PT")) + mac2 = cipher.digest() + + # Decryption + cipher = self.module.new(self.key, self.mode, self.iv) + for c in chunks: + cipher.update(c) + cipher.decrypt(ct) + cipher.verify(mac2) + + def no_mix_encrypt_decrypt(self): + """Verify that encrypt and decrypt cannot be mixed up""" + + self.description = "Test for mix of encrypt and decrypt in %s of %s" % \ + (self.mode_name, self.module.__name__) + + # Calling decrypt after encrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.encrypt(b("PT")) + self.assertRaises(TypeError, cipher.decrypt, b("XYZ")) + + # Calling encrypt after decrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.decrypt(b("CT")) + self.assertRaises(TypeError, cipher.encrypt, b("XYZ")) + + # Calling verify after encrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.encrypt(b("PT")) + self.assertRaises(TypeError, cipher.verify, b("XYZ")) + self.assertRaises(TypeError, cipher.hexverify, "12") + + # Calling digest after decrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.decrypt(b("CT")) + self.assertRaises(TypeError, cipher.digest) + self.assertRaises(TypeError, cipher.hexdigest) + + def no_late_update(self): + """Verify that update cannot be called after encrypt or decrypt""" + + self.description = "Test for late update in %s of %s" % \ + (self.mode_name, self.module.__name__) + + # Calling update after encrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.update(b("XX")) + cipher.encrypt(b("PT")) + self.assertRaises(TypeError, cipher.update, b("XYZ")) + + # Calling update after decrypt raises an exception + cipher = self.module.new(self.key, self.mode, self.iv) + cipher.update(b("XX")) + cipher.decrypt(b("CT")) + self.assertRaises(TypeError, cipher.update, b("XYZ")) + + def runTest(self): + self.right_mac_test() + self.wrong_mac_test() + self.zero_data() + self.multiple_updates() + self.no_mix_encrypt_decrypt() + self.no_late_update() + + def shortDescription(self): + return self.description + class RoundtripTest(unittest.TestCase): def __init__(self, module, params): from Crypto import Random @@ -310,6 +596,10 @@ class IVLengthTest(unittest.TestCase): self.module.MODE_OFB, "") self.assertRaises(ValueError, self.module.new, a2b_hex(self.key), self.module.MODE_OPENPGP, "") + if hasattr(self.module, "MODE_CCM"): + for ivlen in (0,6,14): + self.assertRaises(ValueError, self.module.new, a2b_hex(self.key), + self.module.MODE_CCM, bchr(0)*ivlen, msg_len=10) self.module.new(a2b_hex(self.key), self.module.MODE_ECB, "") self.module.new(a2b_hex(self.key), self.module.MODE_CTR, "", counter=self._dummy_counter) @@ -367,6 +657,13 @@ def make_block_tests(module, module_name, test_data, additional_params=dict()): ] extra_tests_added = 1 + # Extract associated data and MAC for AEAD modes + if p_mode == 'CCM': + assoc_data, params['plaintext'] = params['plaintext'].split('|') + assoc_data2, params['ciphertext'], params['mac'] = params['ciphertext'].split('|') + params['assoc_data'] = assoc_data + params['mac_len'] = len(params['mac'])>>1 + # Add the current test to the test suite tests.append(CipherSelfTest(module, params)) @@ -383,6 +680,19 @@ def make_block_tests(module, module_name, test_data, additional_params=dict()): if not params2['ctr_params'].has_key('disable_shortcut'): params2['ctr_params']['disable_shortcut'] = 1 tests.append(CipherSelfTest(module, params2)) + + # Add tests that don't use test vectors + if hasattr(module, "MODE_CCM"): + tests += [ + CCMMACLengthTest(module), + CCMSplitEncryptionTest(module), + ] + for aead_mode in ("MODE_CCM",): + if hasattr(module, aead_mode): + tests += [ + AEADTests(module, aead_mode), + ] + return tests def make_stream_tests(module, module_name, test_data): |