summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSohaib ul Hassan <sohaibulhassan@tuni.fi>2020-06-16 15:40:57 -0700
committerSohaib ul Hassan <sohaibulhassan@tuni.fi>2020-06-16 15:40:57 -0700
commitad79bd8695b4d08d8fcf71e41a6c1b0275235ddc (patch)
tree665a3d55b55bd5f7a3108026e0c2cf5d143f67fe
parentdd951fb012b9cbac75a46f795ed46cdf6a57aa46 (diff)
downloadnss-hg-ad79bd8695b4d08d8fcf71e41a6c1b0275235ddc.tar.gz
Bug 1631597 - Constant-time GCD and modular inversion r=rrelyea,kjacobs
The implementation is based on the work by Bernstein and Yang (https://eprint.iacr.org/2019/266) "Fast constant-time gcd computation and modular inversion". It fixes the old mp_gcd and s_mp_invmod_odd_m functions. The patch also fix mpl_significant_bits s_mp_div_2d and s_mp_mul_2d by having less control flow to reduce side-channel leaks. Co Author : Billy Bob Brumley Differential Revision: https://phabricator.services.mozilla.com/D78668
-rw-r--r--lib/freebl/mpi/mpi.c378
-rw-r--r--lib/freebl/mpi/mpi.h1
-rw-r--r--lib/freebl/mpi/mplogic.c45
3 files changed, 292 insertions, 132 deletions
diff --git a/lib/freebl/mpi/mpi.c b/lib/freebl/mpi/mpi.c
index 7e96e51ff..1b7b171e7 100644
--- a/lib/freebl/mpi/mpi.c
+++ b/lib/freebl/mpi/mpi.c
@@ -8,6 +8,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "mpi-priv.h"
+#include "mplogic.h"
#if defined(OSF1)
#include <c_asm.h>
#endif
@@ -1688,98 +1689,112 @@ mp_iseven(const mp_int *a)
/* {{{ mp_gcd(a, b, c) */
/*
- Like the old mp_gcd() function, except computes the GCD using the
- binary algorithm due to Josef Stein in 1961 (via Knuth).
+ Computes the GCD using the constant-time algorithm
+ by Bernstein and Yang (https://eprint.iacr.org/2019/266)
+ "Fast constant-time gcd computation and modular inversion"
*/
mp_err
mp_gcd(mp_int *a, mp_int *b, mp_int *c)
{
mp_err res;
- mp_int u, v, t;
- mp_size k = 0;
+ mp_digit cond = 0, mask = 0;
+ mp_int g, temp, f;
+ int i, j, m, bit = 1, delta = 1, shifts = 0, last = -1;
+ mp_size top, flen, glen;
+ mp_int *clear[3];
ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
-
- if (mp_cmp_z(a) == MP_EQ && mp_cmp_z(b) == MP_EQ)
- return MP_RANGE;
+ /*
+ Early exit if either of the inputs is zero.
+ Caller is responsible for the proper handling of inputs.
+ */
if (mp_cmp_z(a) == MP_EQ) {
- return mp_copy(b, c);
+ res = mp_copy(b, c);
+ SIGN(c) = ZPOS;
+ return res;
} else if (mp_cmp_z(b) == MP_EQ) {
- return mp_copy(a, c);
- }
-
- if ((res = mp_init(&t)) != MP_OKAY)
+ res = mp_copy(a, c);
+ SIGN(c) = ZPOS;
return res;
- if ((res = mp_init_copy(&u, a)) != MP_OKAY)
- goto U;
- if ((res = mp_init_copy(&v, b)) != MP_OKAY)
- goto V;
-
- SIGN(&u) = ZPOS;
- SIGN(&v) = ZPOS;
-
- /* Divide out common factors of 2 until at least 1 of a, b is even */
- while (mp_iseven(&u) && mp_iseven(&v)) {
- s_mp_div_2(&u);
- s_mp_div_2(&v);
- ++k;
}
- /* Initialize t */
- if (mp_isodd(&u)) {
- if ((res = mp_copy(&v, &t)) != MP_OKAY)
- goto CLEANUP;
-
- /* t = -v */
- if (SIGN(&v) == ZPOS)
- SIGN(&t) = NEG;
- else
- SIGN(&t) = ZPOS;
+ MP_CHECKOK(mp_init(&temp));
+ clear[++last] = &temp;
+ MP_CHECKOK(mp_init_copy(&g, a));
+ clear[++last] = &g;
+ MP_CHECKOK(mp_init_copy(&f, b));
+ clear[++last] = &f;
- } else {
- if ((res = mp_copy(&u, &t)) != MP_OKAY)
- goto CLEANUP;
+ /*
+ For even case compute the number of
+ shared powers of 2 in f and g.
+ */
+ for (i = 0; i < USED(&f) && i < USED(&g); i++) {
+ mask = ~(DIGIT(&f, i) | DIGIT(&g, i));
+ for (j = 0; j < MP_DIGIT_BIT; j++) {
+ bit &= mask;
+ shifts += bit;
+ mask >>= 1;
+ }
}
+ /* Reduce to the odd case by removing the powers of 2. */
+ s_mp_div_2d(&f, shifts);
+ s_mp_div_2d(&g, shifts);
- for (;;) {
- while (mp_iseven(&t)) {
- s_mp_div_2(&t);
- }
+ /* Allocate to the size of largest mp_int. */
+ top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
+ MP_CHECKOK(s_mp_grow(&f, top));
+ MP_CHECKOK(s_mp_grow(&g, top));
+ MP_CHECKOK(s_mp_grow(&temp, top));
- if (mp_cmp_z(&t) == MP_GT) {
- if ((res = mp_copy(&t, &u)) != MP_OKAY)
- goto CLEANUP;
+ /* Make sure f contains the odd value. */
+ MP_CHECKOK(mp_cswap((~DIGIT(&f, 0) & 1), &f, &g, top));
- } else {
- if ((res = mp_copy(&t, &v)) != MP_OKAY)
- goto CLEANUP;
+ /* Upper bound for the total iterations. */
+ flen = mpl_significant_bits(&f);
+ glen = mpl_significant_bits(&g);
+ m = 4 + 3 * ((flen >= glen) ? flen : glen);
- /* v = -t */
- if (SIGN(&t) == ZPOS)
- SIGN(&v) = NEG;
- else
- SIGN(&v) = ZPOS;
- }
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
- if ((res = mp_sub(&u, &v, &t)) != MP_OKAY)
- goto CLEANUP;
+ for (i = 0; i < m; i++) {
+ /* Step 1: conditional swap. */
+ /* Set cond if delta > 0 and g is odd. */
+ cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
+ /* If cond is set replace (delta,f) with (-delta,-f). */
+ delta = (-cond & -delta) | ((cond - 1) & delta);
+ SIGN(&f) ^= cond;
+ /* If cond is set swap f with g. */
+ MP_CHECKOK(mp_cswap(cond, &f, &g, top));
+
+ /* Step 2: elemination. */
+ /* Update delta. */
+ delta++;
+ /* If g is odd, right shift (g+f) else right shift g. */
+ MP_CHECKOK(mp_add(&g, &f, &temp));
+ MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
+ s_mp_div_2(&g);
+ }
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
- if (s_mp_cmp_d(&t, 0) == MP_EQ)
- break;
- }
+ /* GCD is in f, take the absolute value. */
+ SIGN(&f) = ZPOS;
- s_mp_2expt(&v, k); /* v = 2^k */
- res = mp_mul(&u, &v, c); /* c = u * v */
+ /* Add back the removed powers of 2. */
+ MP_CHECKOK(s_mp_mul_2d(&f, shifts));
-CLEANUP:
- mp_clear(&v);
-V:
- mp_clear(&u);
-U:
- mp_clear(&t);
+ MP_CHECKOK(mp_copy(&f, c));
+CLEANUP:
+ while (last >= 0)
+ mp_clear(clear[last--]);
return res;
-
} /* end mp_gcd() */
/* }}} */
@@ -2131,42 +2146,114 @@ CLEANUP:
return res;
}
-/* compute mod inverse using Schroeppel's method, only if m is odd */
+/*
+ Computes the modular inverse using the constant-time algorithm
+ by Bernstein and Yang (https://eprint.iacr.org/2019/266)
+ "Fast constant-time gcd computation and modular inversion"
+ */
mp_err
s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c)
{
- int k;
mp_err res;
- mp_int x;
+ mp_digit cond = 0;
+ mp_int g, f, v, r, temp;
+ int i, its, delta = 1, last = -1;
+ mp_size top, flen, glen;
+ mp_int *clear[6];
ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
-
- if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
+ /* Check for invalid inputs. */
+ if (mp_cmp_z(a) == MP_EQ || mp_cmp_d(m, 2) == MP_LT)
return MP_RANGE;
- if (mp_iseven(m))
+
+ if (a == m || mp_iseven(m))
return MP_UNDEF;
- MP_DIGITS(&x) = 0;
+ MP_CHECKOK(mp_init(&temp));
+ clear[++last] = &temp;
+ MP_CHECKOK(mp_init(&v));
+ clear[++last] = &v;
+ MP_CHECKOK(mp_init(&r));
+ clear[++last] = &r;
+ MP_CHECKOK(mp_init_copy(&g, a));
+ clear[++last] = &g;
+ MP_CHECKOK(mp_init_copy(&f, m));
+ clear[++last] = &f;
+
+ mp_set(&v, 0);
+ mp_set(&r, 1);
+
+ /* Allocate to the size of largest mp_int. */
+ top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
+ MP_CHECKOK(s_mp_grow(&f, top));
+ MP_CHECKOK(s_mp_grow(&g, top));
+ MP_CHECKOK(s_mp_grow(&temp, top));
+ MP_CHECKOK(s_mp_grow(&v, top));
+ MP_CHECKOK(s_mp_grow(&r, top));
+
+ /* Upper bound for the total iterations. */
+ flen = mpl_significant_bits(&f);
+ glen = mpl_significant_bits(&g);
+ its = 4 + 3 * ((flen >= glen) ? flen : glen);
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
- if (a == c) {
- if ((res = mp_init_copy(&x, a)) != MP_OKAY)
- return res;
- if (a == m)
- m = &x;
- a = &x;
- } else if (m == c) {
- if ((res = mp_init_copy(&x, m)) != MP_OKAY)
- return res;
- m = &x;
- } else {
- MP_DIGITS(&x) = 0;
+ for (i = 0; i < its; i++) {
+ /* Step 1: conditional swap. */
+ /* Set cond if delta > 0 and g is odd. */
+ cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
+ /* If cond is set replace (delta,f,v) with (-delta,-f,-v). */
+ delta = (-cond & -delta) | ((cond - 1) & delta);
+ SIGN(&f) ^= cond;
+ SIGN(&v) ^= cond;
+ /* If cond is set swap (f,v) with (g,r). */
+ MP_CHECKOK(mp_cswap(cond, &f, &g, top));
+ MP_CHECKOK(mp_cswap(cond, &v, &r, top));
+
+ /* Step 2: elemination. */
+ /* Update delta */
+ delta++;
+ /* If g is odd replace r with (r+v). */
+ MP_CHECKOK(mp_add(&r, &v, &temp));
+ MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &r, &temp, top));
+ /* If g is odd, right shift (g+f) else right shift g. */
+ MP_CHECKOK(mp_add(&g, &f, &temp));
+ MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
+ s_mp_div_2(&g);
+ /*
+ If r is even, right shift it.
+ If r is odd, right shift (r+m) which is even because m is odd.
+ We want the result modulo m so adding in multiples of m here vanish.
+ */
+ MP_CHECKOK(mp_add(&r, m, &temp));
+ MP_CHECKOK(mp_cswap((DIGIT(&r, 0) & 1), &r, &temp, top));
+ s_mp_div_2(&r);
}
- MP_CHECKOK(s_mp_almost_inverse(a, m, c));
- k = res;
- MP_CHECKOK(s_mp_fixup_reciprocal(c, m, k, c));
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+ /* We have the inverse in v, propagate sign from f. */
+ SIGN(&v) ^= SIGN(&f);
+ /* GCD is in f, take the absolute value. */
+ SIGN(&f) = ZPOS;
+
+ /* If gcd != 1, not invertible. */
+ if (mp_cmp_d(&f, 1) != MP_EQ) {
+ res = MP_UNDEF;
+ goto CLEANUP;
+ }
+
+ /* Return inverse modulo m. */
+ MP_CHECKOK(mp_mod(&v, m, c));
+
CLEANUP:
- mp_clear(&x);
+ while (last >= 0)
+ mp_clear(clear[last--]);
return res;
}
@@ -2218,13 +2305,24 @@ s_mp_invmod_2d(const mp_int *a, mp_size k, mp_int *c)
if (mp_iseven(a))
return MP_UNDEF;
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
if (k <= MP_DIGIT_BIT) {
mp_digit i = s_mp_invmod_radix(MP_DIGIT(a, 0));
+ /* propagate the sign from mp_int */
+ i = (i ^ -(mp_digit)SIGN(a)) + (mp_digit)SIGN(a);
if (k < MP_DIGIT_BIT)
i &= ((mp_digit)1 << k) - (mp_digit)1;
mp_set(c, i);
return MP_OKAY;
}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
MP_DIGITS(&t0) = 0;
MP_DIGITS(&t1) = 0;
MP_DIGITS(&val) = 0;
@@ -2831,6 +2929,8 @@ s_mp_clamp(mp_int *mp)
while (used > 1 && DIGIT(mp, used - 1) == 0)
--used;
MP_USED(mp) = used;
+ if (used == 1 && DIGIT(mp, 0) == 0)
+ MP_SIGN(mp) = ZPOS;
} /* end s_mp_clamp() */
/* }}} */
@@ -2908,37 +3008,36 @@ mp_err
s_mp_mul_2d(mp_int *mp, mp_digit d)
{
mp_err res;
- mp_digit dshift, bshift;
- mp_digit mask;
+ mp_digit dshift, rshift, mask, x, prev = 0;
+ mp_digit *pa = NULL;
+ int i;
ARGCHK(mp != NULL, MP_BADARG);
dshift = d / MP_DIGIT_BIT;
- bshift = d % MP_DIGIT_BIT;
+ d %= MP_DIGIT_BIT;
+ /* mp_digit >> rshift is undefined behavior for rshift >= MP_DIGIT_BIT */
+ /* mod and corresponding mask logic avoid that when d = 0 */
+ rshift = MP_DIGIT_BIT - d;
+ rshift %= MP_DIGIT_BIT;
+ /* mask = (2**d - 1) * 2**(w-d) mod 2**w */
+ mask = (DIGIT_MAX << rshift) + 1;
+ mask &= DIGIT_MAX - 1;
/* bits to be shifted out of the top word */
- if (bshift) {
- mask = (mp_digit)~0 << (MP_DIGIT_BIT - bshift);
- mask &= MP_DIGIT(mp, MP_USED(mp) - 1);
- } else {
- mask = 0;
- }
+ x = MP_DIGIT(mp, MP_USED(mp) - 1) & mask;
- if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (mask != 0))))
+ if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (x != 0))))
return res;
if (dshift && MP_OKAY != (res = s_mp_lshd(mp, dshift)))
return res;
- if (bshift) {
- mp_digit *pa = MP_DIGITS(mp);
- mp_digit *alim = pa + MP_USED(mp);
- mp_digit prev = 0;
+ pa = MP_DIGITS(mp) + dshift;
- for (pa += dshift; pa < alim;) {
- mp_digit x = *pa;
- *pa++ = (x << bshift) | prev;
- prev = x >> (DIGIT_BIT - bshift);
- }
+ for (i = MP_USED(mp) - dshift; i > 0; i--) {
+ x = *pa;
+ *pa++ = (x << d) | prev;
+ prev = (x & mask) >> rshift;
}
s_mp_clamp(mp);
@@ -3077,18 +3176,20 @@ void
s_mp_div_2d(mp_int *mp, mp_digit d)
{
int ix;
- mp_digit save, next, mask;
+ mp_digit save, next, mask, lshift;
s_mp_rshd(mp, d / DIGIT_BIT);
d %= DIGIT_BIT;
- if (d) {
- mask = ((mp_digit)1 << d) - 1;
- save = 0;
- for (ix = USED(mp) - 1; ix >= 0; ix--) {
- next = DIGIT(mp, ix) & mask;
- DIGIT(mp, ix) = (DIGIT(mp, ix) >> d) | (save << (DIGIT_BIT - d));
- save = next;
- }
+ /* mp_digit << lshift is undefined behavior for lshift >= MP_DIGIT_BIT */
+ /* mod and corresponding mask logic avoid that when d = 0 */
+ lshift = DIGIT_BIT - d;
+ lshift %= DIGIT_BIT;
+ mask = ((mp_digit)1 << d) - 1;
+ save = 0;
+ for (ix = USED(mp) - 1; ix >= 0; ix--) {
+ next = DIGIT(mp, ix) & mask;
+ DIGIT(mp, ix) = (save << lshift) | (DIGIT(mp, ix) >> d);
+ save = next;
}
s_mp_clamp(mp);
@@ -4841,5 +4942,44 @@ mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length)
} /* end mp_to_fixlen_octets() */
/* }}} */
+/* {{{ mp_cswap(condition, a, b, numdigits) */
+/* performs a conditional swap between mp_int. */
+mp_err
+mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits)
+{
+ mp_digit x;
+ unsigned int i;
+ mp_err res = 0;
+
+ /* if pointers are equal return */
+ if (a == b)
+ return res;
+
+ if (MP_ALLOC(a) < numdigits || MP_ALLOC(b) < numdigits) {
+ MP_CHECKOK(s_mp_grow(a, numdigits));
+ MP_CHECKOK(s_mp_grow(b, numdigits));
+ }
+
+ condition = ((~condition & ((condition - 1))) >> (MP_DIGIT_BIT - 1)) - 1;
+
+ x = (USED(a) ^ USED(b)) & condition;
+ USED(a) ^= x;
+ USED(b) ^= x;
+
+ x = (SIGN(a) ^ SIGN(b)) & condition;
+ SIGN(a) ^= x;
+ SIGN(b) ^= x;
+
+ for (i = 0; i < numdigits; i++) {
+ x = (DIGIT(a, i) ^ DIGIT(b, i)) & condition;
+ DIGIT(a, i) ^= x;
+ DIGIT(b, i) ^= x;
+ }
+
+CLEANUP:
+ return res;
+} /* end mp_cswap() */
+/* }}} */
+
/*------------------------------------------------------------------------*/
/* HERE THERE BE DRAGONS */
diff --git a/lib/freebl/mpi/mpi.h b/lib/freebl/mpi/mpi.h
index af608b43d..b1a07a61d 100644
--- a/lib/freebl/mpi/mpi.h
+++ b/lib/freebl/mpi/mpi.h
@@ -267,6 +267,7 @@ mp_size mp_trailing_zeros(const mp_int *mp);
void freebl_cpuid(unsigned long op, unsigned long *eax,
unsigned long *ebx, unsigned long *ecx,
unsigned long *edx);
+mp_err mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits);
#define MP_CHECKOK(x) \
if (MP_OKAY > (res = (x))) \
diff --git a/lib/freebl/mpi/mplogic.c b/lib/freebl/mpi/mplogic.c
index 31fc56d34..db19cff13 100644
--- a/lib/freebl/mpi/mplogic.c
+++ b/lib/freebl/mpi/mplogic.c
@@ -407,35 +407,54 @@ mpl_get_bits(const mp_int *a, mp_size lsbNum, mp_size numBits)
return (mp_err)mask;
}
+#define LZCNTLOOP(i) \
+ do { \
+ x = d >> (i); \
+ mask = (0 - x); \
+ mask = (0 - (mask >> (MP_DIGIT_BIT - 1))); \
+ bits += (i)&mask; \
+ d ^= (x ^ d) & mask; \
+ } while (0)
+
/*
mpl_significant_bits
- returns number of significnant bits in abs(a).
+ returns number of significant bits in abs(a).
+ In other words: floor(lg(abs(a))) + 1.
returns 1 if value is zero.
*/
mp_size
mpl_significant_bits(const mp_int *a)
{
- mp_size bits = 0;
+ /*
+ start bits at 1.
+ lg(0) = 0 => bits = 1 by function semantics.
+ below does a binary search for the _position_ of the top bit set,
+ which is floor(lg(abs(a))) for a != 0.
+ */
+ mp_size bits = 1;
int ix;
ARGCHK(a != NULL, MP_BADARG);
for (ix = MP_USED(a); ix > 0;) {
- mp_digit d;
- d = MP_DIGIT(a, --ix);
- if (d) {
- while (d) {
- ++bits;
- d >>= 1;
- }
- break;
- }
+ mp_digit d, x, mask;
+ if ((d = MP_DIGIT(a, --ix)) == 0)
+ continue;
+#if !defined(MP_USE_UINT_DIGIT)
+ LZCNTLOOP(32);
+#endif
+ LZCNTLOOP(16);
+ LZCNTLOOP(8);
+ LZCNTLOOP(4);
+ LZCNTLOOP(2);
+ LZCNTLOOP(1);
+ break;
}
bits += ix * MP_DIGIT_BIT;
- if (!bits)
- bits = 1;
return bits;
}
+#undef LZCNTLOOP
+
/*------------------------------------------------------------------------*/
/* HERE THERE BE DRAGONS */