diff options
Diffstat (limited to 'board/cr50/dcrypto/rsa.c')
-rw-r--r-- | board/cr50/dcrypto/rsa.c | 302 |
1 files changed, 172 insertions, 130 deletions
diff --git a/board/cr50/dcrypto/rsa.c b/board/cr50/dcrypto/rsa.c index 9c721abc5e..524d3d14f4 100644 --- a/board/cr50/dcrypto/rsa.c +++ b/board/cr50/dcrypto/rsa.c @@ -31,9 +31,9 @@ static uint32_t select(uint32_t mask, uint32_t a, uint32_t b) /* We use SHA256 context to store SHA1 context, so make sure it's ok. */ BUILD_ASSERT(sizeof(struct sha256_ctx) >= sizeof(struct sha1_ctx)); -static void MGF1_xor(uint8_t *dst, uint32_t dst_len, - const uint8_t *seed, uint32_t seed_len, - enum hashing_mode hashing) +static enum dcrypto_result MGF1_xor(uint8_t *dst, uint32_t dst_len, + const uint8_t *seed, uint32_t seed_len, + enum hashing_mode hashing) { union hash_ctx ctx; @@ -44,17 +44,19 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len, uint8_t b0; } cnt; const uint8_t *digest; - const size_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE : - SHA256_DIGEST_SIZE; + const size_t hash_size = DCRYPTO_hash_size(hashing); + + if (!hash_size) + return DCRYPTO_FAIL; cnt.b0 = cnt.b1 = cnt.b2 = cnt.b3 = 0; while (dst_len) { size_t i; + enum dcrypto_result result; - if (hashing == HASH_SHA1) - SHA1_hw_init(&ctx.sha1); - else - SHA256_hw_init(&ctx.sha256); + result = DCRYPTO_hw_hash_init(&ctx, hashing); + if (result != DCRYPTO_OK) + return result; HASH_update(&ctx, seed, seed_len); HASH_update(&ctx, (uint8_t *)&cnt, sizeof(cnt)); @@ -65,6 +67,7 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len, if (!++cnt.b0) ++cnt.b1; } + return DCRYPTO_OK; } /* @@ -78,12 +81,12 @@ static void MGF1_xor(uint8_t *dst, uint32_t dst_len, * }; */ /* encrypt */ -static int oaep_pad(uint8_t *output, uint32_t output_len, - const uint8_t *msg, uint32_t msg_len, - enum hashing_mode hashing, const char *label) +static enum dcrypto_result oaep_pad(uint8_t *output, uint32_t output_len, + const uint8_t *msg, uint32_t msg_len, + enum hashing_mode hashing, + const char *label) { - const size_t hash_size = (hashing == HASH_SHA1) ? SHA_DIGEST_SIZE - : SHA256_DIGEST_SIZE; + const size_t hash_size = DCRYPTO_hash_size(hashing); uint8_t *const seed = output + 1; uint8_t *const phash = seed + hash_size; uint8_t *const PS = phash + hash_size; @@ -91,39 +94,44 @@ static int oaep_pad(uint8_t *output, uint32_t output_len, const uint32_t ps_len = max_msg_len - msg_len; uint8_t *const one = PS + ps_len; union hash_ctx ctx; + enum dcrypto_result result; + if (!hash_size) + return DCRYPTO_FAIL; if (output_len < 2 + 2 * hash_size) - return 0; /* Key size too small for chosen hash. */ + return DCRYPTO_FAIL; /* Key size too small for chosen hash. */ if (msg_len > output_len - 2 - 2 * hash_size) - return 0; /* Input message too large for key size. */ + return DCRYPTO_FAIL; /* Input message too large for key size. */ always_memset(output, 0, output_len); if (!fips_rand_bytes(seed, hash_size)) - return 0; + return DCRYPTO_FAIL; - if (hashing == HASH_SHA1) - SHA1_hw_init(&ctx.sha1); - else - SHA256_hw_init(&ctx.sha256); + result = DCRYPTO_hw_hash_init(&ctx, hashing); + if (result != DCRYPTO_OK) + return result; HASH_update(&ctx, label, label ? strlen(label) + 1 : 0); memcpy(phash, HASH_final(&ctx)->b8, hash_size); *one = 1; memcpy(one + 1, msg, msg_len); - MGF1_xor(phash, hash_size + 1 + max_msg_len, - seed, hash_size, hashing); - MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len, - hashing); - return 1; + result = MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size, + hashing); + result |= MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len, + hashing); + + if (result != DCRYPTO_OK) + return DCRYPTO_FAIL; + + return result; } /* decrypt */ -static int check_oaep_pad(uint8_t *out, size_t *out_len, +static enum dcrypto_result check_oaep_pad(uint8_t *out, size_t *out_len, uint8_t *padded, size_t padded_len, enum hashing_mode hashing, const char *label) { - const size_t hash_size = (hashing == HASH_SHA1) ? SHA_DIGEST_SIZE - : SHA256_DIGEST_SIZE; + const size_t hash_size = DCRYPTO_hash_size(hashing); uint8_t *seed = padded + 1; uint8_t *phash = seed + hash_size; uint8_t *PS = phash + hash_size; @@ -133,27 +141,35 @@ static int check_oaep_pad(uint8_t *out, size_t *out_len, uint32_t looking_for_one_byte = ~0; int bad; size_t i; + enum dcrypto_result result; + + if (!hash_size) + return DCRYPTO_FAIL; if (padded_len < 2 + 2 * hash_size) - return 0; /* Invalid input size. */ + return DCRYPTO_FAIL; /* Invalid input size. */ /* Recover seed. */ - MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len, hashing); + result = MGF1_xor(seed, hash_size, phash, hash_size + 1 + max_msg_len, + hashing); /* Recover db. */ - MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size, hashing); + result |= MGF1_xor(phash, hash_size + 1 + max_msg_len, seed, hash_size, + hashing); + + if (result != DCRYPTO_OK) + return DCRYPTO_FAIL; - if (hashing == HASH_SHA1) - SHA1_hw_init(&ctx.sha1); - else - SHA256_hw_init(&ctx.sha256); + result = DCRYPTO_hw_hash_init(&ctx, hashing); + if (result != DCRYPTO_OK) + return result; HASH_update(&ctx, label, label ? strlen(label) + 1 : 0); - /* bad should be zero if CRYPTO_OK is returned. */ - bad = DCRYPTO_equals(phash, HASH_final(&ctx)->b8, hash_size) - - DCRYPTO_OK; + /* bad should be zero if DCRYPTO_OK is returned. */ + result = DCRYPTO_equals(phash, HASH_final(&ctx)->b8, hash_size); + bad = result - DCRYPTO_OK; /* bad = 0 if result == DCRYPTO_OK */ bad |= padded[0]; - for (i = PS - padded; i < padded_len; i++) { + for (i = PS - padded; i < padded_len; i++) { uint32_t equals0 = is_zero(padded[i]); uint32_t equals1 = is_zero(padded[i] ^ 1); @@ -168,29 +184,30 @@ static int check_oaep_pad(uint8_t *out, size_t *out_len, bad |= looking_for_one_byte; if (bad) - return 0; + return DCRYPTO_FAIL; one_index++; if (*out_len < padded_len - one_index) - return 0; + return DCRYPTO_FAIL; memcpy(out, padded + one_index, padded_len - one_index); *out_len = padded_len - one_index; - return 1; + /* Result should be DCRYPTO_OK after DCRYPTO_equals() */ + return result; } /* Constants from RFC 3447. */ #define RSA_PKCS1_PADDING_SIZE 11 /* encrypt */ -static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len, - const uint8_t *in, size_t in_len) +static enum dcrypto_result pkcs1_type2_pad(uint8_t *padded, size_t padded_len, + const uint8_t *in, size_t in_len) { size_t PS_len; if (padded_len < RSA_PKCS1_PADDING_SIZE) - return 0; + return DCRYPTO_FAIL; if (in_len > padded_len - RSA_PKCS1_PADDING_SIZE) - return 0; + return DCRYPTO_FAIL; PS_len = padded_len - 3 - in_len; *(padded++) = 0; @@ -200,7 +217,7 @@ static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len, uint8_t r[SHA256_DIGEST_SIZE]; if (!fips_rand_bytes(r, sizeof(r))) - return 0; + return DCRYPTO_FAIL; /** * zero byte has special meaning in PKCS1, so copy @@ -215,11 +232,11 @@ static int pkcs1_type2_pad(uint8_t *padded, size_t padded_len, } *(padded++) = 0; memcpy(padded, in, in_len); - return 1; + return DCRYPTO_OK; } /* decrypt */ -static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len, +static enum dcrypto_result check_pkcs1_type2_pad(uint8_t *out, size_t *out_len, const uint8_t *padded, size_t padded_len) { size_t i; @@ -228,7 +245,7 @@ static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len, uint32_t looking_for_index = ~0; if (padded_len < RSA_PKCS1_PADDING_SIZE) - return 0; + return DCRYPTO_FAIL; valid = (padded[0] == 0); valid &= (padded[1] == 2); @@ -245,13 +262,14 @@ static int check_pkcs1_type2_pad(uint8_t *out, size_t *out_len, valid &= ~looking_for_index; valid &= (zero_index >= RSA_PKCS1_PADDING_SIZE); if (!valid) - return 0; + return DCRYPTO_FAIL; if (*out_len < padded_len - zero_index) - return 0; + return DCRYPTO_FAIL; + memcpy(out, &padded[zero_index], padded_len - zero_index); *out_len = padded_len - zero_index; - return 1; + return DCRYPTO_OK; } static const uint8_t SHA1_DER[] = { @@ -274,8 +292,9 @@ static const uint8_t SHA512_DER[] = { 0x00, 0x04, 0x40 }; -static int pkcs1_get_der(enum hashing_mode hashing, const uint8_t **der, - size_t *der_size, size_t *hash_size) +static enum dcrypto_result pkcs1_get_der(enum hashing_mode hashing, + const uint8_t **der, size_t *der_size, + size_t *hash_size) { switch (hashing) { case HASH_SHA1: @@ -301,33 +320,35 @@ static int pkcs1_get_der(enum hashing_mode hashing, const uint8_t **der, case HASH_NULL: *der = NULL; *der_size = 0; - *hash_size = 0; /* any size allowed */ + *hash_size = 0; /* any size allowed */ break; default: - return 0; + return DCRYPTO_FAIL; } - return 1; + return DCRYPTO_OK; } /* sign */ -static int pkcs1_type1_pad(uint8_t *padded, size_t padded_len, - const uint8_t *in, size_t in_len, - enum hashing_mode hashing) +static enum dcrypto_result pkcs1_type1_pad(uint8_t *padded, size_t padded_len, + const uint8_t *in, size_t in_len, + enum hashing_mode hashing) { const uint8_t *der; size_t der_size; size_t hash_size; size_t ps_len; + enum dcrypto_result result; - if (!pkcs1_get_der(hashing, &der, &der_size, &hash_size)) - return 0; + result = pkcs1_get_der(hashing, &der, &der_size, &hash_size); + if (result != DCRYPTO_OK) + return result; if (padded_len < RSA_PKCS1_PADDING_SIZE + der_size) - return 0; + return DCRYPTO_FAIL; if (!in_len || (hash_size && in_len != hash_size)) - return 0; + return DCRYPTO_FAIL; if (in_len > padded_len - RSA_PKCS1_PADDING_SIZE - der_size) - return 0; + return DCRYPTO_FAIL; ps_len = padded_len - 3 - der_size - in_len; *(padded++) = 0; @@ -338,73 +359,83 @@ static int pkcs1_type1_pad(uint8_t *padded, size_t padded_len, memcpy(padded, der, der_size); padded += der_size; memcpy(padded, in, in_len); - return 1; + return DCRYPTO_OK; } /* verify */ -static int check_pkcs1_type1_pad(const uint8_t *msg, size_t msg_len, - const uint8_t *padded, size_t padded_len, - enum hashing_mode hashing) +static enum dcrypto_result check_pkcs1_type1_pad(const uint8_t *msg, + size_t msg_len, + const uint8_t *padded, + size_t padded_len, + enum hashing_mode hashing) { size_t i; const uint8_t *der; size_t der_size; size_t hash_size; size_t ps_len; + enum dcrypto_result result; - if (!pkcs1_get_der(hashing, &der, &der_size, &hash_size)) - return 0; + result = pkcs1_get_der(hashing, &der, &der_size, &hash_size); + if (result != DCRYPTO_OK) + return result; if (msg_len != hash_size) - return 0; + return DCRYPTO_FAIL; if (padded_len < RSA_PKCS1_PADDING_SIZE + der_size + hash_size) - return 0; + return DCRYPTO_FAIL; ps_len = padded_len - 3 - der_size - hash_size; if (padded[0] != 0 || padded[1] != 1) - return 0; + return DCRYPTO_FAIL; for (i = 2; i < ps_len + 2; i++) { if (padded[i] != 0xFF) - return 0; + return DCRYPTO_FAIL; } if (padded[i++] != 0) - return 0; - if (DCRYPTO_equals(&padded[i], der, der_size) != DCRYPTO_OK) - return 0; + return DCRYPTO_FAIL; + + result = DCRYPTO_equals(&padded[i], der, der_size); i += der_size; - return DCRYPTO_equals(msg, &padded[i], hash_size) == DCRYPTO_OK; + result |= DCRYPTO_equals(msg, &padded[i], hash_size); + if (result != DCRYPTO_OK) + return DCRYPTO_FAIL; + return result; } /* sign */ -static int pkcs1_pss_pad(uint8_t *padded, size_t padded_len, - const uint8_t *in, size_t in_len, - enum hashing_mode hashing) +static enum dcrypto_result pkcs1_pss_pad(uint8_t *padded, size_t padded_len, + const uint8_t *in, size_t in_len, + enum hashing_mode hashing) { - const uint32_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE - : SHA256_DIGEST_SIZE; + const uint32_t hash_size = DCRYPTO_hash_size(hashing); const uint32_t salt_len = MIN(padded_len - hash_size - 2, hash_size); size_t db_len; size_t ps_len; union hash_ctx ctx; + enum dcrypto_result result; + if (!hash_size) + return DCRYPTO_FAIL; if (in_len != hash_size) - return 0; + return DCRYPTO_FAIL; if (padded_len < hash_size + 2) - return 0; + return DCRYPTO_FAIL; db_len = padded_len - hash_size - 1; - if (hashing == HASH_SHA1) - SHA1_hw_init(&ctx.sha1); - else - SHA256_hw_init(&ctx.sha256); + result = DCRYPTO_hw_hash_init(&ctx, hashing); + if (result != DCRYPTO_OK) + return result; /* Pilfer bits of output for temporary use. */ memset(padded, 0, 8); HASH_update(&ctx, padded, 8); HASH_update(&ctx, in, in_len); /* Pilfer bits of output for temporary use. */ - if (!fips_rand_bytes(padded, salt_len)) - return 0; + if (!fips_rand_bytes(padded, salt_len)) { + HASH_final(&ctx); /* free up SHA engine */ + return DCRYPTO_FAIL; + } HASH_update(&ctx, padded, salt_len); /* Output hash. */ @@ -415,22 +446,22 @@ static int pkcs1_pss_pad(uint8_t *padded, size_t padded_len, memmove(padded + ps_len + 1, padded, salt_len); memset(padded, 0, ps_len); padded[ps_len] = 0x01; - MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing); + result = MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing); /* Clear most significant bit. */ padded[0] &= 0x7F; /* Set trailing byte. */ padded[padded_len - 1] = 0xBC; - return 1; + return result; } /* verify */ -static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len, - uint8_t *padded, size_t padded_len, - enum hashing_mode hashing) +static enum dcrypto_result check_pkcs1_pss_pad(const uint8_t *in, size_t in_len, + uint8_t *padded, + size_t padded_len, + enum hashing_mode hashing) { - const uint32_t hash_size = (hashing == HASH_SHA1) ? SHA1_DIGEST_SIZE - : SHA256_DIGEST_SIZE; + const uint32_t hash_size = DCRYPTO_hash_size(hashing); const uint8_t zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0}; uint32_t db_len; uint32_t max_ps_len; @@ -438,11 +469,14 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len, union hash_ctx ctx; int bad = 0; size_t i; + enum dcrypto_result result; + if (!hash_size) + return DCRYPTO_FAIL; if (in_len != hash_size) - return 0; + return DCRYPTO_FAIL; if (padded_len < hash_size + 2) - return 0; + return DCRYPTO_FAIL; db_len = padded_len - hash_size - 1; /* Top bit should be zero. */ @@ -451,7 +485,9 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len, bad |= padded[padded_len - 1] ^ 0xBC; /* Recover DB. */ - MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing); + result = MGF1_xor(padded, db_len, padded + db_len, hash_size, hashing); + bad |= result - DCRYPTO_OK; + /* Clear top bit. */ padded[0] &= 0x7F; /* Verify padding2. */ @@ -466,16 +502,18 @@ static int check_pkcs1_pss_pad(const uint8_t *in, size_t in_len, /* Continue with zero-length salt if 0x01 was not found. */ salt_len = max_ps_len - i; - if (hashing == HASH_SHA1) - SHA1_hw_init(&ctx.sha1); - else - SHA256_hw_init(&ctx.sha256); + result |= DCRYPTO_hw_hash_init(&ctx, hashing); + if (result != DCRYPTO_OK) + return DCRYPTO_FAIL; + HASH_update(&ctx, zeros, sizeof(zeros)); HASH_update(&ctx, in, in_len); HASH_update(&ctx, padded + db_len - salt_len, salt_len); - bad |= DCRYPTO_equals(padded + db_len, HASH_final(&ctx), hash_size) - - DCRYPTO_OK; - return !bad; + result |= DCRYPTO_equals(padded + db_len, HASH_final(&ctx), hash_size); + bad |= result - DCRYPTO_OK; + if (bad) + result = DCRYPTO_FAIL; + return result; } static int check_modulus_params( @@ -511,13 +549,14 @@ int DCRYPTO_rsa_encrypt(struct RSA *rsa, uint8_t *out, size_t *out_len, switch (padding) { case PADDING_MODE_OAEP: - if (!oaep_pad((uint8_t *) padded.d, bn_size(&padded), - (const uint8_t *) in, in_len, hashing, label)) + if (oaep_pad((uint8_t *)padded.d, bn_size(&padded), + (const uint8_t *)in, in_len, hashing, + label) != DCRYPTO_OK) return 0; break; case PADDING_MODE_PKCS1: - if (!pkcs1_type2_pad((uint8_t *) padded.d, bn_size(&padded), - (const uint8_t *) in, in_len)) + if (pkcs1_type2_pad((uint8_t *)padded.d, bn_size(&padded), + (const uint8_t *)in, in_len) != DCRYPTO_OK) return 0; break; case PADDING_MODE_NULL: @@ -579,14 +618,15 @@ int DCRYPTO_rsa_decrypt(struct RSA *rsa, uint8_t *out, size_t *out_len, switch (padding) { case PADDING_MODE_OAEP: - if (!check_oaep_pad(out, out_len, (uint8_t *) padded.d, - bn_size(&padded), hashing, label)) + if (check_oaep_pad(out, out_len, (uint8_t *)padded.d, + bn_size(&padded), hashing, + label) != DCRYPTO_OK) ret = 0; break; case PADDING_MODE_PKCS1: - if (!check_pkcs1_type2_pad( - out, out_len, (const uint8_t *) padded.d, - bn_size(&padded))) + if (check_pkcs1_type2_pad(out, out_len, + (const uint8_t *)padded.d, + bn_size(&padded)) != DCRYPTO_OK) ret = 0; break; case PADDING_MODE_NULL: @@ -626,13 +666,15 @@ int DCRYPTO_rsa_sign(struct RSA *rsa, uint8_t *out, size_t *out_len, switch (padding) { case PADDING_MODE_PKCS1: - if (!pkcs1_type1_pad((uint8_t *) padded.d, bn_size(&padded), - (const uint8_t *) in, in_len, hashing)) + if (pkcs1_type1_pad((uint8_t *)padded.d, bn_size(&padded), + (const uint8_t *)in, in_len, + hashing) != DCRYPTO_OK) return 0; break; case PADDING_MODE_PSS: - if (!pkcs1_pss_pad((uint8_t *) padded.d, bn_size(&padded), - (const uint8_t *) in, in_len, hashing)) + if (pkcs1_pss_pad((uint8_t *)padded.d, bn_size(&padded), + (const uint8_t *)in, in_len, + hashing) != DCRYPTO_OK) return 0; break; default: @@ -679,15 +721,15 @@ int DCRYPTO_rsa_verify(const struct RSA *rsa, const uint8_t *digest, switch (padding) { case PADDING_MODE_PKCS1: - if (!check_pkcs1_type1_pad( - digest, digest_len, (uint8_t *) padded.d, - bn_size(&padded), hashing)) + if (check_pkcs1_type1_pad(digest, digest_len, + (uint8_t *)padded.d, bn_size(&padded), + hashing) != DCRYPTO_OK) ret = 0; break; case PADDING_MODE_PSS: - if (!check_pkcs1_pss_pad( - digest, digest_len, (uint8_t *) padded.d, - bn_size(&padded), hashing)) + if (check_pkcs1_pss_pad(digest, digest_len, (uint8_t *)padded.d, + bn_size(&padded), + hashing) != DCRYPTO_OK) ret = 0; break; default: |