diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/rsa/rsa-sign.c | 60 | ||||
-rw-r--r-- | lib/rsa/rsa-verify.c | 93 |
2 files changed, 145 insertions, 8 deletions
diff --git a/lib/rsa/rsa-sign.c b/lib/rsa/rsa-sign.c index 83f5e87838..f4d433867a 100644 --- a/lib/rsa/rsa-sign.c +++ b/lib/rsa/rsa-sign.c @@ -261,10 +261,57 @@ err_priv: } /* + * rsa_get_exponent(): - Get the public exponent from an RSA key + */ +static int rsa_get_exponent(RSA *key, uint64_t *e) +{ + int ret; + BIGNUM *bn_te; + uint64_t te; + + ret = -EINVAL; + bn_te = NULL; + + if (!e) + goto cleanup; + + if (BN_num_bits(key->e) > 64) + goto cleanup; + + *e = BN_get_word(key->e); + + if (BN_num_bits(key->e) < 33) { + ret = 0; + goto cleanup; + } + + bn_te = BN_dup(key->e); + if (!bn_te) + goto cleanup; + + if (!BN_rshift(bn_te, bn_te, 32)) + goto cleanup; + + if (!BN_mask_bits(bn_te, 32)) + goto cleanup; + + te = BN_get_word(bn_te); + te <<= 32; + *e |= te; + ret = 0; + +cleanup: + if (bn_te) + BN_free(bn_te); + + return ret; +} + +/* * rsa_get_params(): - Get the important parameters of an RSA public key */ -int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp, - BIGNUM **r_squaredp) +int rsa_get_params(RSA *key, uint64_t *exponent, uint32_t *n0_invp, + BIGNUM **modulusp, BIGNUM **r_squaredp) { BIGNUM *big1, *big2, *big32, *big2_32; BIGNUM *n, *r, *r_squared, *tmp; @@ -286,6 +333,9 @@ int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp, return -ENOMEM; } + if (0 != rsa_get_exponent(key, exponent)) + ret = -1; + if (!BN_copy(n, key->n) || !BN_set_word(big1, 1L) || !BN_set_word(big2, 2L) || !BN_set_word(big32, 32L)) ret = -1; @@ -386,6 +436,7 @@ static int fdt_add_bignum(void *blob, int noffset, const char *prop_name, int rsa_add_verify_data(struct image_sign_info *info, void *keydest) { BIGNUM *modulus, *r_squared; + uint64_t exponent; uint32_t n0_inv; int parent, node; char name[100]; @@ -397,7 +448,7 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest) ret = rsa_get_pub_key(info->keydir, info->keyname, &rsa); if (ret) return ret; - ret = rsa_get_params(rsa, &n0_inv, &modulus, &r_squared); + ret = rsa_get_params(rsa, &exponent, &n0_inv, &modulus, &r_squared); if (ret) return ret; bits = BN_num_bits(modulus); @@ -442,6 +493,9 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest) if (!ret) ret = fdt_setprop_u32(keydest, node, "rsa,n0-inverse", n0_inv); if (!ret) { + ret = fdt_setprop_u64(keydest, node, "rsa,exponent", exponent); + } + if (!ret) { ret = fdt_add_bignum(keydest, node, "rsa,modulus", modulus, bits); } diff --git a/lib/rsa/rsa-verify.c b/lib/rsa/rsa-verify.c index bcb906368d..c5bcdb60e8 100644 --- a/lib/rsa/rsa-verify.c +++ b/lib/rsa/rsa-verify.c @@ -26,6 +26,9 @@ #define get_unaligned_be32(a) fdt32_to_cpu(*(uint32_t *)a) #define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a)) +/* Default public exponent for backward compatibility */ +#define RSA_DEFAULT_PUBEXP 65537 + /** * subtract_modulus() - subtract modulus from the given value * @@ -123,6 +126,48 @@ static void montgomery_mul(const struct rsa_public_key *key, } /** + * num_pub_exponent_bits() - Number of bits in the public exponent + * + * @key: RSA key + * @num_bits: Storage for the number of public exponent bits + */ +static int num_public_exponent_bits(const struct rsa_public_key *key, + int *num_bits) +{ + uint64_t exponent; + int exponent_bits; + const uint max_bits = (sizeof(exponent) * 8); + + exponent = key->exponent; + exponent_bits = 0; + + if (!exponent) { + *num_bits = exponent_bits; + return 0; + } + + for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits) + if (!(exponent >>= 1)) { + *num_bits = exponent_bits; + return 0; + } + + return -EINVAL; +} + +/** + * is_public_exponent_bit_set() - Check if a bit in the public exponent is set + * + * @key: RSA key + * @pos: The bit position to check + */ +static int is_public_exponent_bit_set(const struct rsa_public_key *key, + int pos) +{ + return key->exponent & (1ULL << pos); +} + +/** * pow_mod() - in-place public exponentiation * * @key: RSA key @@ -132,6 +177,7 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout) { uint32_t *result, *ptr; uint i; + int j, k; /* Sanity check for stack size - key->len is in 32-bit words */ if (key->len > RSA_MAX_KEY_BITS / 32) { @@ -141,18 +187,48 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout) } uint32_t val[key->len], acc[key->len], tmp[key->len]; + uint32_t a_scaled[key->len]; result = tmp; /* Re-use location. */ /* Convert from big endian byte array to little endian word array. */ for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--) val[i] = get_unaligned_be32(ptr); - montgomery_mul(key, acc, val, key->rr); /* axx = a * RR / R mod M */ - for (i = 0; i < 16; i += 2) { - montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod M */ - montgomery_mul(key, acc, tmp, tmp); /* acc = tmp^2 / R mod M */ + if (0 != num_public_exponent_bits(key, &k)) + return -EINVAL; + + if (k < 2) { + debug("Public exponent is too short (%d bits, minimum 2)\n", + k); + return -EINVAL; } - montgomery_mul(key, result, acc, val); /* result = XX * a / R mod M */ + + if (!is_public_exponent_bit_set(key, 0)) { + debug("LSB of RSA public exponent must be set.\n"); + return -EINVAL; + } + + /* the bit at e[k-1] is 1 by definition, so start with: C := M */ + montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */ + /* retain scaled version for intermediate use */ + memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0])); + + for (j = k - 2; j > 0; --j) { + montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */ + + if (is_public_exponent_bit_set(key, j)) { + /* acc = tmp * val / R mod n */ + montgomery_mul(key, acc, tmp, a_scaled); + } else { + /* e[j] == 0, copy tmp back to acc for next operation */ + memcpy(acc, tmp, key->len * sizeof(acc[0])); + } + } + + /* the bit at e[0] is always 1 */ + montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */ + montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */ + memcpy(result, acc, key->len * sizeof(result[0])); /* Make sure result < mod; result is at most 1x mod too large. */ if (greater_equal_modulus(key, result)) @@ -229,6 +305,8 @@ static int rsa_verify_with_keynode(struct image_sign_info *info, const void *blob = info->fdt_blob; struct rsa_public_key key; const void *modulus, *rr; + const uint64_t *public_exponent; + int length; int ret; if (node < 0) { @@ -241,6 +319,11 @@ static int rsa_verify_with_keynode(struct image_sign_info *info, } key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0); key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0); + public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length); + if (!public_exponent || length < sizeof(*public_exponent)) + key.exponent = RSA_DEFAULT_PUBEXP; + else + key.exponent = fdt64_to_cpu(*public_exponent); modulus = fdt_getprop(blob, node, "rsa,modulus", NULL); rr = fdt_getprop(blob, node, "rsa,r-squared", NULL); if (!key.len || !modulus || !rr) { |