summaryrefslogtreecommitdiff
path: root/board/cr50/dcrypto/rsa.c
diff options
context:
space:
mode:
Diffstat (limited to 'board/cr50/dcrypto/rsa.c')
-rw-r--r--board/cr50/dcrypto/rsa.c302
1 files changed, 172 insertions, 130 deletions
diff --git a/board/cr50/dcrypto/rsa.c b/board/cr50/dcrypto/rsa.c
index 9c721abc5e..524d3d14f4 100644
--- a/board/cr50/dcrypto/rsa.c
+++ b/board/cr50/dcrypto/rsa.c
@@ -31,9 +31,9 @@ static uint32_t select(uint32_t mask, uint32_t a, uint32_t b)
/* We use SHA256 context to store SHA1 context, so make sure it's ok. */
BUILD_ASSERT(sizeof(struct sha256_ctx) >= sizeof(struct sha1_ctx));
-static void MGF1_xor(uint8_t *dst, uint32_t dst_len,
- const uint8_t *seed, uint32_t seed_len,
- enum hashing_mode hashing)
+static enum dcrypto_result MGF1_xor(uint8_t *dst, uint32_t dst_len,
+ const uint8_t *seed, uint32_t seed_len,
+ enum hashing_mode hashing)
{
union hash_ctx ctx;
@@ -44,17 +44,19 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len,
uint8_t b0;
} cnt;
const uint8_t *digest;
- const size_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE :
- SHA256_DIGEST_SIZE;
+ const size_t hash_size = DCRYPTO_hash_size(hashing);
+
+ if (!hash_size)
+ return DCRYPTO_FAIL;
cnt.b0 = cnt.b1 = cnt.b2 = cnt.b3 = 0;
while (dst_len) {
size_t i;
+ enum dcrypto_result result;
- if (hashing == HASH_SHA1)
- SHA1_hw_init(&ctx.sha1);
- else
- SHA256_hw_init(&ctx.sha256);
+ result = DCRYPTO_hw_hash_init(&ctx, hashing);
+ if (result != DCRYPTO_OK)
+ return result;
HASH_update(&ctx, seed, seed_len);
HASH_update(&ctx, (uint8_t *)&cnt, sizeof(cnt));
@@ -65,6 +67,7 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len,
if (!++cnt.b0)
++cnt.b1;
}
+ return DCRYPTO_OK;
}
/*
@@ -78,12 +81,12 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len,
* };
*/
/* encrypt */
-static int oaep_pad(uint8_t *output, uint32_t output_len,
- const uint8_t *msg, uint32_t msg_len,
- enum hashing_mode hashing, const char *label)
+static enum dcrypto_result oaep_pad(uint8_t *output, uint32_t output_len,
+ const uint8_t *msg, uint32_t msg_len,
+ enum hashing_mode hashing,
+ const char *label)
{
- const size_t hash_size = (hashing == HASH_SHA1) ? SHA_DIGEST_SIZE
- : SHA256_DIGEST_SIZE;
+ const size_t hash_size = DCRYPTO_hash_size(hashing);
uint8_t *const seed = output + 1;
uint8_t *const phash = seed + hash_size;
uint8_t *const PS = phash + hash_size;
@@ -91,39 +94,44 @@ static int oaep_pad(uint8_t *output, uint32_t output_len,
const uint32_t ps_len = max_msg_len - msg_len;
uint8_t *const one = PS + ps_len;
union hash_ctx ctx;
+ enum dcrypto_result result;
+ if (!hash_size)
+ return DCRYPTO_FAIL;
if (output_len < 2 + 2 * hash_size)
- return 0; /* Key size too small for chosen hash. */
+ return DCRYPTO_FAIL; /* Key size too small for chosen hash. */
if (msg_len > output_len - 2 - 2 * hash_size)
- return 0; /* Input message too large for key size. */
+ return DCRYPTO_FAIL; /* Input message too large for key size. */
always_memset(output, 0, output_len);
if (!fips_rand_bytes(seed, hash_size))
- return 0;
+ return DCRYPTO_FAIL;
- if (hashing == HASH_SHA1)
- SHA1_hw_init(&ctx.sha1);
- else
- SHA256_hw_init(&ctx.sha256);
+ result = DCRYPTO_hw_hash_init(&ctx, hashing);
+ if (result != DCRYPTO_OK)
+ return result;
HASH_update(&ctx, label, label ? strlen(label) + 1 : 0);
memcpy(phash, HASH_final(&ctx)->b8, hash_size);
*one = 1;
memcpy(one + 1, msg, msg_len);
- MGF1_xor(phash, hash_size + 1 + max_msg_len,
- seed, hash_size, hashing);
- MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len,
- hashing);
- return 1;
+ result = MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size,
+ hashing);
+ result |= MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len,
+ hashing);
+
+ if (result != DCRYPTO_OK)
+ return DCRYPTO_FAIL;
+
+ return result;
}
/* decrypt */
-static int check_oaep_pad(uint8_t *out, size_t *out_len,
+static enum dcrypto_result check_oaep_pad(uint8_t *out, size_t *out_len,
uint8_t *padded, size_t padded_len,
enum hashing_mode hashing, const char *label)
{
- const size_t hash_size = (hashing == HASH_SHA1) ? SHA_DIGEST_SIZE
- : SHA256_DIGEST_SIZE;
+ const size_t hash_size = DCRYPTO_hash_size(hashing);
uint8_t *seed = padded + 1;
uint8_t *phash = seed + hash_size;
uint8_t *PS = phash + hash_size;
@@ -133,27 +141,35 @@ static int check_oaep_pad(uint8_t *out, size_t *out_len,
uint32_t looking_for_one_byte = ~0;
int bad;
size_t i;
+ enum dcrypto_result result;
+
+ if (!hash_size)
+ return DCRYPTO_FAIL;
if (padded_len < 2 + 2 * hash_size)
- return 0; /* Invalid input size. */
+ return DCRYPTO_FAIL; /* Invalid input size. */
/* Recover seed. */
- MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len, hashing);
+ result = MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len,
+ hashing);
/* Recover db. */
- MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size, hashing);
+ result |= MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size,
+ hashing);
+
+ if (result != DCRYPTO_OK)
+ return DCRYPTO_FAIL;
- if (hashing == HASH_SHA1)
- SHA1_hw_init(&ctx.sha1);
- else
- SHA256_hw_init(&ctx.sha256);
+ result = DCRYPTO_hw_hash_init(&ctx, hashing);
+ if (result != DCRYPTO_OK)
+ return result;
HASH_update(&ctx, label, label ? strlen(label) + 1 : 0);
- /* bad should be zero if CRYPTO_OK is returned. */
- bad = DCRYPTO_equals(phash, HASH_final(&ctx)->b8, hash_size) -
- DCRYPTO_OK;
+ /* bad should be zero if DCRYPTO_OK is returned. */
+ result = DCRYPTO_equals(phash, HASH_final(&ctx)->b8, hash_size);
+ bad = result - DCRYPTO_OK; /* bad = 0 if result == DCRYPTO_OK */
bad |= padded[0];
- for (i = PS - padded; i < padded_len; i++) {
+ for (i = PS - padded; i < padded_len; i++) {
uint32_t equals0 = is_zero(padded[i]);
uint32_t equals1 = is_zero(padded[i] ^ 1);
@@ -168,29 +184,30 @@ static int check_oaep_pad(uint8_t *out, size_t *out_len,
bad |= looking_for_one_byte;
if (bad)
- return 0;
+ return DCRYPTO_FAIL;
one_index++;
if (*out_len < padded_len - one_index)
- return 0;
+ return DCRYPTO_FAIL;
memcpy(out, padded + one_index, padded_len - one_index);
*out_len = padded_len - one_index;
- return 1;
+ /* Result should be DCRYPTO_OK after DCRYPTO_equals() */
+ return result;
}
/* Constants from RFC 3447. */
#define RSA_PKCS1_PADDING_SIZE 11
/* encrypt */
-static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len,
- const uint8_t *in, size_t in_len)
+static enum dcrypto_result pkcs1_type2_pad(uint8_t *padded, size_t padded_len,
+ const uint8_t *in, size_t in_len)
{
size_t PS_len;
if (padded_len < RSA_PKCS1_PADDING_SIZE)
- return 0;
+ return DCRYPTO_FAIL;
if (in_len > padded_len - RSA_PKCS1_PADDING_SIZE)
- return 0;
+ return DCRYPTO_FAIL;
PS_len = padded_len - 3 - in_len;
*(padded++) = 0;
@@ -200,7 +217,7 @@ static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len,
uint8_t r[SHA256_DIGEST_SIZE];
if (!fips_rand_bytes(r, sizeof(r)))
- return 0;
+ return DCRYPTO_FAIL;
/**
* zero byte has special meaning in PKCS1, so copy
@@ -215,11 +232,11 @@ static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len,
}
*(padded++) = 0;
memcpy(padded, in, in_len);
- return 1;
+ return DCRYPTO_OK;
}
/* decrypt */
-static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len,
+static enum dcrypto_result check_pkcs1_type2_pad(uint8_t *out, size_t *out_len,
const uint8_t *padded, size_t padded_len)
{
size_t i;
@@ -228,7 +245,7 @@ static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len,
uint32_t looking_for_index = ~0;
if (padded_len < RSA_PKCS1_PADDING_SIZE)
- return 0;
+ return DCRYPTO_FAIL;
valid = (padded[0] == 0);
valid &= (padded[1] == 2);
@@ -245,13 +262,14 @@ static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len,
valid &= ~looking_for_index;
valid &= (zero_index >= RSA_PKCS1_PADDING_SIZE);
if (!valid)
- return 0;
+ return DCRYPTO_FAIL;
if (*out_len < padded_len - zero_index)
- return 0;
+ return DCRYPTO_FAIL;
+
memcpy(out, &padded[zero_index], padded_len - zero_index);
*out_len = padded_len - zero_index;
- return 1;
+ return DCRYPTO_OK;
}
static const uint8_t SHA1_DER[] = {
@@ -274,8 +292,9 @@ static const uint8_t SHA512_DER[] = {
0x00, 0x04, 0x40
};
-static int pkcs1_get_der(enum hashing_mode hashing, const uint8_t **der,
- size_t *der_size, size_t *hash_size)
+static enum dcrypto_result pkcs1_get_der(enum hashing_mode hashing,
+ const uint8_t **der, size_t *der_size,
+ size_t *hash_size)
{
switch (hashing) {
case HASH_SHA1:
@@ -301,33 +320,35 @@ static int pkcs1_get_der(enum hashing_mode hashing, const uint8_t **der,
case HASH_NULL:
*der = NULL;
*der_size = 0;
- *hash_size = 0; /* any size allowed */
+ *hash_size = 0; /* any size allowed */
break;
default:
- return 0;
+ return DCRYPTO_FAIL;
}
- return 1;
+ return DCRYPTO_OK;
}
/* sign */
-static int pkcs1_type1_pad(uint8_t *padded, size_t padded_len,
- const uint8_t *in, size_t in_len,
- enum hashing_mode hashing)
+static enum dcrypto_result pkcs1_type1_pad(uint8_t *padded, size_t padded_len,
+ const uint8_t *in, size_t in_len,
+ enum hashing_mode hashing)
{
const uint8_t *der;
size_t der_size;
size_t hash_size;
size_t ps_len;
+ enum dcrypto_result result;
- if (!pkcs1_get_der(hashing, &der, &der_size, &hash_size))
- return 0;
+ result = pkcs1_get_der(hashing, &der, &der_size, &hash_size);
+ if (result != DCRYPTO_OK)
+ return result;
if (padded_len < RSA_PKCS1_PADDING_SIZE + der_size)
- return 0;
+ return DCRYPTO_FAIL;
if (!in_len || (hash_size && in_len != hash_size))
- return 0;
+ return DCRYPTO_FAIL;
if (in_len > padded_len - RSA_PKCS1_PADDING_SIZE - der_size)
- return 0;
+ return DCRYPTO_FAIL;
ps_len = padded_len - 3 - der_size - in_len;
*(padded++) = 0;
@@ -338,73 +359,83 @@ static int pkcs1_type1_pad(uint8_t *padded, size_t padded_len,
memcpy(padded, der, der_size);
padded += der_size;
memcpy(padded, in, in_len);
- return 1;
+ return DCRYPTO_OK;
}
/* verify */
-static int check_pkcs1_type1_pad(const uint8_t *msg, size_t msg_len,
- const uint8_t *padded, size_t padded_len,
- enum hashing_mode hashing)
+static enum dcrypto_result check_pkcs1_type1_pad(const uint8_t *msg,
+ size_t msg_len,
+ const uint8_t *padded,
+ size_t padded_len,
+ enum hashing_mode hashing)
{
size_t i;
const uint8_t *der;
size_t der_size;
size_t hash_size;
size_t ps_len;
+ enum dcrypto_result result;
- if (!pkcs1_get_der(hashing, &der, &der_size, &hash_size))
- return 0;
+ result = pkcs1_get_der(hashing, &der, &der_size, &hash_size);
+ if (result != DCRYPTO_OK)
+ return result;
if (msg_len != hash_size)
- return 0;
+ return DCRYPTO_FAIL;
if (padded_len < RSA_PKCS1_PADDING_SIZE + der_size + hash_size)
- return 0;
+ return DCRYPTO_FAIL;
ps_len = padded_len - 3 - der_size - hash_size;
if (padded[0] != 0 || padded[1] != 1)
- return 0;
+ return DCRYPTO_FAIL;
for (i = 2; i < ps_len + 2; i++) {
if (padded[i] != 0xFF)
- return 0;
+ return DCRYPTO_FAIL;
}
if (padded[i++] != 0)
- return 0;
- if (DCRYPTO_equals(&padded[i], der, der_size) != DCRYPTO_OK)
- return 0;
+ return DCRYPTO_FAIL;
+
+ result = DCRYPTO_equals(&padded[i], der, der_size);
i += der_size;
- return DCRYPTO_equals(msg, &padded[i], hash_size) == DCRYPTO_OK;
+ result |= DCRYPTO_equals(msg, &padded[i], hash_size);
+ if (result != DCRYPTO_OK)
+ return DCRYPTO_FAIL;
+ return result;
}
/* sign */
-static int pkcs1_pss_pad(uint8_t *padded, size_t padded_len,
- const uint8_t *in, size_t in_len,
- enum hashing_mode hashing)
+static enum dcrypto_result pkcs1_pss_pad(uint8_t *padded, size_t padded_len,
+ const uint8_t *in, size_t in_len,
+ enum hashing_mode hashing)
{
- const uint32_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE
- : SHA256_DIGEST_SIZE;
+ const uint32_t hash_size = DCRYPTO_hash_size(hashing);
const uint32_t salt_len = MIN(padded_len - hash_size - 2, hash_size);
size_t db_len;
size_t ps_len;
union hash_ctx ctx;
+ enum dcrypto_result result;
+ if (!hash_size)
+ return DCRYPTO_FAIL;
if (in_len != hash_size)
- return 0;
+ return DCRYPTO_FAIL;
if (padded_len < hash_size + 2)
- return 0;
+ return DCRYPTO_FAIL;
db_len = padded_len - hash_size - 1;
- if (hashing == HASH_SHA1)
- SHA1_hw_init(&ctx.sha1);
- else
- SHA256_hw_init(&ctx.sha256);
+ result = DCRYPTO_hw_hash_init(&ctx, hashing);
+ if (result != DCRYPTO_OK)
+ return result;
/* Pilfer bits of output for temporary use. */
memset(padded, 0, 8);
HASH_update(&ctx, padded, 8);
HASH_update(&ctx, in, in_len);
/* Pilfer bits of output for temporary use. */
- if (!fips_rand_bytes(padded, salt_len))
- return 0;
+ if (!fips_rand_bytes(padded, salt_len)) {
+ HASH_final(&ctx); /* free up SHA engine */
+ return DCRYPTO_FAIL;
+ }
HASH_update(&ctx, padded, salt_len);
/* Output hash. */
@@ -415,22 +446,22 @@ static int pkcs1_pss_pad(uint8_t *padded, size_t padded_len,
memmove(padded + ps_len + 1, padded, salt_len);
memset(padded, 0, ps_len);
padded[ps_len] = 0x01;
- MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing);
+ result = MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing);
/* Clear most significant bit. */
padded[0] &= 0x7F;
/* Set trailing byte. */
padded[padded_len - 1] = 0xBC;
- return 1;
+ return result;
}
/* verify */
-static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len,
- uint8_t *padded, size_t padded_len,
- enum hashing_mode hashing)
+static enum dcrypto_result check_pkcs1_pss_pad(const uint8_t *in, size_t in_len,
+ uint8_t *padded,
+ size_t padded_len,
+ enum hashing_mode hashing)
{
- const uint32_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE
- : SHA256_DIGEST_SIZE;
+ const uint32_t hash_size = DCRYPTO_hash_size(hashing);
const uint8_t zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
uint32_t db_len;
uint32_t max_ps_len;
@@ -438,11 +469,14 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len,
union hash_ctx ctx;
int bad = 0;
size_t i;
+ enum dcrypto_result result;
+ if (!hash_size)
+ return DCRYPTO_FAIL;
if (in_len != hash_size)
- return 0;
+ return DCRYPTO_FAIL;
if (padded_len < hash_size + 2)
- return 0;
+ return DCRYPTO_FAIL;
db_len = padded_len - hash_size - 1;
/* Top bit should be zero. */
@@ -451,7 +485,9 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len,
bad |= padded[padded_len - 1] ^ 0xBC;
/* Recover DB. */
- MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing);
+ result = MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing);
+ bad |= result - DCRYPTO_OK;
+
/* Clear top bit. */
padded[0] &= 0x7F;
/* Verify padding2. */
@@ -466,16 +502,18 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len,
/* Continue with zero-length salt if 0x01 was not found. */
salt_len = max_ps_len - i;
- if (hashing == HASH_SHA1)
- SHA1_hw_init(&ctx.sha1);
- else
- SHA256_hw_init(&ctx.sha256);
+ result |= DCRYPTO_hw_hash_init(&ctx, hashing);
+ if (result != DCRYPTO_OK)
+ return DCRYPTO_FAIL;
+
HASH_update(&ctx, zeros, sizeof(zeros));
HASH_update(&ctx, in, in_len);
HASH_update(&ctx, padded + db_len - salt_len, salt_len);
- bad |= DCRYPTO_equals(padded + db_len, HASH_final(&ctx), hash_size) -
- DCRYPTO_OK;
- return !bad;
+ result |= DCRYPTO_equals(padded + db_len, HASH_final(&ctx), hash_size);
+ bad |= result - DCRYPTO_OK;
+ if (bad)
+ result = DCRYPTO_FAIL;
+ return result;
}
static int check_modulus_params(
@@ -511,13 +549,14 @@ int DCRYPTO_rsa_encrypt(struct RSA *rsa, uint8_t *out, size_t *out_len,
switch (padding) {
case PADDING_MODE_OAEP:
- if (!oaep_pad((uint8_t *) padded.d, bn_size(&padded),
- (const uint8_t *) in, in_len, hashing, label))
+ if (oaep_pad((uint8_t *)padded.d, bn_size(&padded),
+ (const uint8_t *)in, in_len, hashing,
+ label) != DCRYPTO_OK)
return 0;
break;
case PADDING_MODE_PKCS1:
- if (!pkcs1_type2_pad((uint8_t *) padded.d, bn_size(&padded),
- (const uint8_t *) in, in_len))
+ if (pkcs1_type2_pad((uint8_t *)padded.d, bn_size(&padded),
+ (const uint8_t *)in, in_len) != DCRYPTO_OK)
return 0;
break;
case PADDING_MODE_NULL:
@@ -579,14 +618,15 @@ int DCRYPTO_rsa_decrypt(struct RSA *rsa, uint8_t *out, size_t *out_len,
switch (padding) {
case PADDING_MODE_OAEP:
- if (!check_oaep_pad(out, out_len, (uint8_t *) padded.d,
- bn_size(&padded), hashing, label))
+ if (check_oaep_pad(out, out_len, (uint8_t *)padded.d,
+ bn_size(&padded), hashing,
+ label) != DCRYPTO_OK)
ret = 0;
break;
case PADDING_MODE_PKCS1:
- if (!check_pkcs1_type2_pad(
- out, out_len, (const uint8_t *) padded.d,
- bn_size(&padded)))
+ if (check_pkcs1_type2_pad(out, out_len,
+ (const uint8_t *)padded.d,
+ bn_size(&padded)) != DCRYPTO_OK)
ret = 0;
break;
case PADDING_MODE_NULL:
@@ -626,13 +666,15 @@ int DCRYPTO_rsa_sign(struct RSA *rsa, uint8_t *out, size_t *out_len,
switch (padding) {
case PADDING_MODE_PKCS1:
- if (!pkcs1_type1_pad((uint8_t *) padded.d, bn_size(&padded),
- (const uint8_t *) in, in_len, hashing))
+ if (pkcs1_type1_pad((uint8_t *)padded.d, bn_size(&padded),
+ (const uint8_t *)in, in_len,
+ hashing) != DCRYPTO_OK)
return 0;
break;
case PADDING_MODE_PSS:
- if (!pkcs1_pss_pad((uint8_t *) padded.d, bn_size(&padded),
- (const uint8_t *) in, in_len, hashing))
+ if (pkcs1_pss_pad((uint8_t *)padded.d, bn_size(&padded),
+ (const uint8_t *)in, in_len,
+ hashing) != DCRYPTO_OK)
return 0;
break;
default:
@@ -679,15 +721,15 @@ int DCRYPTO_rsa_verify(const struct RSA *rsa, const uint8_t *digest,
switch (padding) {
case PADDING_MODE_PKCS1:
- if (!check_pkcs1_type1_pad(
- digest, digest_len, (uint8_t *) padded.d,
- bn_size(&padded), hashing))
+ if (check_pkcs1_type1_pad(digest, digest_len,
+ (uint8_t *)padded.d, bn_size(&padded),
+ hashing) != DCRYPTO_OK)
ret = 0;
break;
case PADDING_MODE_PSS:
- if (!check_pkcs1_pss_pad(
- digest, digest_len, (uint8_t *) padded.d,
- bn_size(&padded), hashing))
+ if (check_pkcs1_pss_pad(digest, digest_len, (uint8_t *)padded.d,
+ bn_size(&padded),
+ hashing) != DCRYPTO_OK)
ret = 0;
break;
default: