diff options
Diffstat (limited to 'libgo/go/math/big/nat.go')
-rw-r--r-- | libgo/go/math/big/nat.go | 167 |
1 files changed, 91 insertions, 76 deletions
diff --git a/libgo/go/math/big/nat.go b/libgo/go/math/big/nat.go index 6d81823bb4a..2d5a5c9587d 100644 --- a/libgo/go/math/big/nat.go +++ b/libgo/go/math/big/nat.go @@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) { // Operands that are shorter than karatsubaThreshold are multiplied using // "grade school" multiplication; for longer operands the Karatsuba algorithm // is used. -var karatsubaThreshold int = 32 // computed by calibrate.go +var karatsubaThreshold int = 40 // computed by calibrate.go // karatsuba multiplies x and y and leaves the result in z. // Both x and y must have the same length n and n must be a @@ -342,7 +342,7 @@ func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } -// addAt implements z += x*(1<<(_W*i)); z must be long enough. +// addAt implements z += x<<(_W*i); z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) func addAt(z, x nat, i int) { @@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat { // determine Karatsuba length k such that // - // x = x1*b + x0 - // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) // b = 1<<(_W*k) ("base" of digits xi, yi) // k := karatsubaLen(n) @@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat { y0 := y[0:k] // y0 is not normalized z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage - - // If x1 and/or y1 are not 0, add missing terms to z explicitly: - // - // m+n 2*k 0 - // z = [ ... | x0*y0 ] - // + [ x1*y1 ] - // + [ x1*y0 ] - // + [ x0*y1 ] + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 + // + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - x1 := x[k:] // x1 is normalized because x is - y1 := y[k:] // y1 is normalized because y is var t nat - t = t.mul(x1, y1) - copy(z[2*k:], t) - z[2*k+len(t):].clear() // upper portion of z is garbage - t = t.mul(x1, y0.norm()) - addAt(z, t, k) - t = t.mul(x0.norm(), y1) + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array addAt(z, t, k) + + // add xi*y0<<i, xi*y1*b<<(i+k) + y0 := y0.norm() + for i := k; i < len(x); i += k { + xi := x[i:] + if len(xi) > k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + addAt(z, t, i) + t = t.mul(xi, y1) + addAt(z, t, i+k) + } } return z.norm() @@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) { } if len(v) == 1 { - var rprime Word - q, rprime = z.divW(u, v[0]) - if rprime > 0 { - r = z2.make(1) - r[0] = rprime - } else { - r = z2.make(0) - } + var r2 Word + q, r2 = z.divW(u, v[0]) + r = z2.setWord(r2) return } @@ -740,7 +752,7 @@ func (x nat) string(charset string) string { // convert power of two and non power of two bases separately if b == b&-b { // shift is base-b digit size in bits - shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2 + shift := trailingZeroBits(b) // shift > 0 because b >= 2 mask := Word(1)<<shift - 1 w := x[0] nbits := uint(_W) // number of unprocessed bits in w @@ -993,10 +1005,9 @@ var deBruijn64Lookup = []byte{ 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, } -// trailingZeroBits returns the number of consecutive zero bits on the right -// side of the given Word. -// See Knuth, volume 4, section 7.3.1 -func trailingZeroBits(x Word) int { +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func trailingZeroBits(x Word) uint { // x & -x leaves only the right-most bit set in the word. Let k be the // index of that bit. Since only a single bit is set, the value is two // to the power of k. Multiplying by a power of two is equivalent to @@ -1005,18 +1016,33 @@ func trailingZeroBits(x Word) int { // Therefore, if we have a left shifted version of this constant we can // find by how many bits it was shifted by looking at which six bit // substring ended up at the top of the word. + // (Knuth, volume 4, section 7.3.1) switch _W { case 32: - return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) + return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) case 64: - return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) + return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) default: - panic("Unknown word size") + panic("unknown word size") } return 0 } +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func (x nat) trailingZeroBits() uint { + if len(x) == 0 { + return 0 + } + var i uint + for x[i] == 0 { + i++ + } + // x[i] != 0 + return i*_W + trailingZeroBits(x[i]) +} + // z = x << s func (z nat) shl(x nat, s uint) nat { m := len(x) @@ -1169,29 +1195,6 @@ func (x nat) modW(d Word) (r Word) { return divWVW(q, 0, x, d) } -// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0. -func (x nat) powersOfTwoDecompose() (q nat, k int) { - if len(x) == 0 { - return x, 0 - } - - // One of the words must be non-zero by definition, - // so this loop will terminate with i < len(x), and - // i is the number of 0 words. - i := 0 - for x[i] == 0 { - i++ - } - n := trailingZeroBits(x[i]) // x[i] != 0 - - q = make(nat, len(x)-i) - shrVU(q, x[i:], uint(n)) - - q = q.norm() - k = i*_W + n - return -} - // random creates a random integer in [0..limit), using the space in z if // possible. n is the bit length of limit. func (z nat) random(rand *rand.Rand, limit nat, n int) nat { @@ -1207,17 +1210,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { mask := Word((1 << bitLengthOfMSW) - 1) for { - for i := range z { - switch _W { - case 32: + switch _W { + case 32: + for i := range z { z[i] = Word(rand.Uint32()) - case 64: + } + case 64: + for i := range z { z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 } + default: + panic("unknown word size") } - z[len(limit)-1] &= mask - if z.cmp(limit) < 0 { break } @@ -1230,7 +1235,7 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { // reuses the storage of z if possible. func (z nat) expNN(x, y, m nat) nat { if alias(z, x) || alias(z, y) { - // We cannot allow in place modification of x or y. + // We cannot allow in-place modification of x or y. z = nil } @@ -1259,15 +1264,21 @@ func (z nat) expNN(x, y, m nat) nat { // we also multiply by x, thus adding one to the power. w := _W - int(shift) + // zz and r are used to avoid allocating in mul and div as + // otherwise the arguments would alias. + var zz, r nat for j := 0; j < w; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } if m != nil { - q, z = q.div(z, z, m) + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1277,14 +1288,17 @@ func (z nat) expNN(x, y, m nat) nat { v = y[i] for j := 0; j < _W; j++ { - z = z.mul(z, z) + zz = zz.mul(z, z) + zz, z = z, zz if v&mask != 0 { - z = z.mul(z, x) + zz = zz.mul(z, x) + zz, z = z, zz } if m != nil { - q, z = q.div(z, z, m) + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r } v <<= 1 @@ -1343,8 +1357,9 @@ func (n nat) probablyPrime(reps int) bool { } nm1 := nat(nil).sub(n, natOne) - // 1<<k * q = nm1; - q, k := nm1.powersOfTwoDecompose() + // determine q, k such that nm1 = q << k + k := nm1.trailingZeroBits() + q := nat(nil).shr(nm1, k) nm3 := nat(nil).sub(nm1, natTwo) rand := rand.New(rand.NewSource(int64(n[0]))) @@ -1360,7 +1375,7 @@ NextRandom: if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } - for j := 1; j < k; j++ { + for j := uint(1); j < k; j++ { y = y.mul(y, y) quotient, y = quotient.div(y, y, n) if y.cmp(nm1) == 0 { |