summaryrefslogtreecommitdiff
path: root/test/tpm_test/rsa_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/tpm_test/rsa_test.py')
-rw-r--r--test/tpm_test/rsa_test.py110
1 files changed, 32 insertions, 78 deletions
diff --git a/test/tpm_test/rsa_test.py b/test/tpm_test/rsa_test.py
index 1d377b3ae2..52754b38d8 100644
--- a/test/tpm_test/rsa_test.py
+++ b/test/tpm_test/rsa_test.py
@@ -8,7 +8,6 @@
import binascii
import os
import struct
-
import Crypto
import Crypto.Hash.SHA
import Crypto.Hash.SHA256
@@ -89,83 +88,38 @@ _KEYS = {
# 0x00 LSB DIGEST LEN
# .... DIGEST
#
-_RSA_CMD_FORMAT = '{o:c}{p:c}{h:c}{kl:s}{ml:s}{msg}{dl:s}{dig}'
-
-
def _decrypt_cmd(padding, hashing, key_len, msg):
- rsa_op = _RSA_OPCODES['DECRYPT']
- msg_len = len(msg)
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', msg_len), msg=msg,
- dl='', dig='')
-
+ return struct.pack('>BBBHH', _RSA_OPCODES['DECRYPT'], padding, hashing,
+ key_len, len(msg)) + msg + bytes([0, 0])
def _encrypt_cmd(padding, hashing, key_len, msg):
- rsa_op = _RSA_OPCODES['ENCRYPT']
- msg_len = len(msg)
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', msg_len), msg=msg,
- dl='', dig='')
-
+ return struct.pack('>BBBHH', _RSA_OPCODES['ENCRYPT'], padding, hashing,
+ key_len, len(msg)) + msg + bytes([0, 0])
def _sign_cmd(padding, hashing, key_len, digest):
- rsa_op = _RSA_OPCODES['SIGN']
- digest_len = len(digest)
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', digest_len), msg=digest,
- dl='', dig='')
-
+ return struct.pack('>BBBHH', _RSA_OPCODES['SIGN'], padding, hashing,
+ key_len, len(digest)) + digest + bytes([0, 0])
def _verify_cmd(padding, hashing, key_len, sig, digest):
- rsa_op = _RSA_OPCODES['VERIFY']
- sig_len = len(sig)
- digest_len = len(digest)
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', sig_len), msg=sig,
- dl=struct.pack('>H', digest_len), dig=digest)
-
+ return struct.pack('>BBBHH', _RSA_OPCODES['VERIFY'], padding, hashing,
+ key_len, len(sig)) + sig +\
+ len(digest).to_bytes(2, 'big') + digest
def _keytest_cmd(key_len):
- rsa_op = _RSA_OPCODES['KEYTEST']
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=0, h=_HASH['NONE'],
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', 0), msg='',
- dl='', dig='')
-
+ return struct.pack('>BBBHHH', _RSA_OPCODES['KEYTEST'], 0, 0, key_len, 0, 0)
def _keygen_cmd(key_len, exponent, label):
assert exponent == 65537
- rsa_op = _RSA_OPCODES['KEYGEN']
- padding = _RSA_PADDING['NONE']
- hashing = _HASH['NONE']
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', len(label)), msg=label,
- dl=struct.pack('>H', 0), dig='')
-
+ return struct.pack('>BBBHH', _RSA_OPCODES['KEYGEN'], 0, 0, key_len,
+ len(label)) + label + bytes([0, 0])
def _primegen_cmd(seed):
- rsa_op = _RSA_OPCODES['PRIMEGEN']
- padding = _RSA_PADDING['NONE']
- hashing = _HASH['NONE']
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', len(seed) * 8 * 2),
- ml=struct.pack('>H', len(seed)), msg=seed,
- dl=struct.pack('>H', 0), dig='')
+ return struct.pack('>BBBHH', _RSA_OPCODES['PRIMEGEN'], 0, 0, len(seed) * 16,
+ len(seed)) + seed + bytes([0, 0])
def _x509_verify_cmd(key_len):
- rsa_op = _RSA_OPCODES['X509_VERIFY']
- padding = _RSA_PADDING['NONE']
- hashing = _HASH['NONE']
- return _RSA_CMD_FORMAT.format(o=rsa_op, p=padding, h=hashing,
- kl=struct.pack('>H', key_len),
- ml=struct.pack('>H', 0), msg='',
- dl=struct.pack('>H', 0), dig='')
-
+ return struct.pack('>BBBHHH', _RSA_OPCODES['X509_VERIFY'], 0, 0,
+ key_len, 0, 0)
_PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127,
@@ -579,13 +533,13 @@ def _prime_from_seed(seed):
return window
# Set LSB, and top two bits.
- candidate = chr(ord(seed[0]) | 192) + seed[1:-1] + chr(ord(seed[-1]) | 1)
+ candidate = bytes([(seed[0] | 192)]) + seed[1:-1] + bytes([seed[-1] | 1])
candidate = int(binascii.b2a_hex(candidate), 16)
assert len(bin(candidate)[2:]) == len(seed) * 8
window = _window(candidate, _PRIMES[:4096])
for i, bit in enumerate(window):
if not bit:
- if rsa.prime.randomized_primality_testing(candidate + i, rounds):
+ if rsa.prime.miller_rabin_primality_testing(candidate + i, rounds):
return candidate + i
return None
@@ -633,11 +587,11 @@ _KEYTEST_INPUTS = (
)
_KEYGEN_INPUTS = (
- (768, 65537, '', None),
- (1024, 65537, 'rsa_test', None),
+ (768, 65537, b'', None),
+ (1024, 65537, b'rsa_test', None),
# pylint: disable=line-too-long
- (2048, 65537, 'RSA key by vendor', 20811475686431332186511278472307159547870512766846593830860105577496044159545322178313772755518365593670114793803805067608811418757734989708137784444223785391864604211835387393923163468734914392307047296990698533218399115126417934050463597455237478939601236799120239663591264311485133747167378663829046579164891864068853210530642835833947569643788911200934265596274935082689832626616967124524353322373059893974744194447740045242468136414689225322177212281193879756355471091445748150740871146034049776312457888356154834233819876846764944450478069436248506560967902863015152471662817623176815923756421011384149834497587),
- (2048, 65537, '', None),
+ (2048, 65537, b'RSA key by vendor', 20811475686431332186511278472307159547870512766846593830860105577496044159545322178313772755518365593670114793803805067608811418757734989708137784444223785391864604211835387393923163468734914392307047296990698533218399115126417934050463597455237478939601236799120239663591264311485133747167378663829046579164891864068853210530642835833947569643788911200934265596274935082689832626616967124524353322373059893974744194447740045242468136414689225322177212281193879756355471091445748150740871146034049776312457888356154834233819876846764944450478069436248506560967902863015152471662817623176815923756421011384149834497587),
+ (2048, 65537, b'', None),
)
# 2048-bit will be done in hardware (i.e. fast), rest are in software.
@@ -653,7 +607,7 @@ _PRIMEGEN_INPUTS = (
)
def _encrypt_tests(tpm):
- msg = 'Hello CR50!'
+ msg = b'Hello CR50!'
for data in _ENCRYPT_INPUTS:
padding, hashing, key_len = data
@@ -710,7 +664,7 @@ def _verify_tests(tpm):
key_len, signature, msg_hash.digest())
wrapped_response = tpm.command(tpm.wrap_ext_command(subcmd.RSA, cmd))
verified = tpm.unwrap_ext_response(subcmd.RSA, wrapped_response)
- expected = '\x01'
+ expected = b'\x01'
if verified != expected:
raise subcmd.TpmTestError('%s error:%s%s' % (
test_name, utils.hex_dump(verified), utils.hex_dump(expected)))
@@ -724,7 +678,7 @@ def _keytest_tests(tpm):
cmd = _keytest_cmd(key_len)
wrapped_response = tpm.command(tpm.wrap_ext_command(subcmd.RSA, cmd))
valid = tpm.unwrap_ext_response(subcmd.RSA, wrapped_response)
- expected = '\x01'
+ expected = b'\x01'
if valid != expected:
raise subcmd.TpmTestError('%s error:%s%s' % (
test_name, utils.hex_dump(valid), utils.hex_dump(expected)))
@@ -746,12 +700,12 @@ def _keygen_tests(tpm):
raise subcmd.TpmTestError('%s error:%s' % (
test_name, utils.hex_dump(result)))
- N = int(binascii.b2a_hex(result[0:result_len * 2 / 3]), 16)
+ N = int(binascii.b2a_hex(result[0:result_len * 2 // 3]), 16)
if expected_N and N != expected_N:
raise subcmd.TpmTestError('%s error:%s' %
(test_name, utils.hex_dump(result)))
- p = int(binascii.b2a_hex(result[result_len * 2 / 3:]), 16)
- q = N / p
+ p = int(binascii.b2a_hex(result[result_len * 2 // 3:]), 16)
+ q = N // p
if not rsa.prime.is_prime(p):
raise subcmd.TpmTestError('%s error:%s' %
(test_name, utils.hex_dump(result)))
@@ -768,15 +722,15 @@ def _primegen_tests(tpm):
for data in _PRIMEGEN_INPUTS:
key_len = data
test_name = 'RSA-PRIMEGEN:%d' % data
- seed = rsa.randnum.read_random_bits(key_len / 2)
- assert len(seed) == key_len / 16
+ seed = rsa.randnum.read_random_bits(key_len // 2)
+ assert len(seed) == key_len // 16
# dcrypto interface is little-endian.
cmd = _primegen_cmd(seed[::-1])
wrapped_response = tpm.command(tpm.wrap_ext_command(subcmd.RSA, cmd))
result = tpm.unwrap_ext_response(subcmd.RSA, wrapped_response)
result_len = len(result)
- if result_len != key_len / 16:
+ if result_len != key_len // 16:
raise subcmd.TpmTestError('%s error:%s' % (
test_name, utils.hex_dump(result)))
@@ -796,7 +750,7 @@ def _x509_verify_tests(tpm):
cmd = _x509_verify_cmd(2048)
wrapped_response = tpm.command(tpm.wrap_ext_command(subcmd.RSA, cmd))
valid = tpm.unwrap_ext_response(subcmd.RSA, wrapped_response)
- expected = '\x01'
+ expected = b'\x01'
if valid != expected:
raise subcmd.TpmTestError('%s error:%s%s' % (
test_name, utils.hex_dump(valid), utils.hex_dump(expected)))