summaryrefslogtreecommitdiff
path: root/libgo/go/math/big/nat.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/math/big/nat.go')
-rw-r--r--libgo/go/math/big/nat.go167
1 files changed, 91 insertions, 76 deletions
diff --git a/libgo/go/math/big/nat.go b/libgo/go/math/big/nat.go
index 6d81823bb4a..2d5a5c9587d 100644
--- a/libgo/go/math/big/nat.go
+++ b/libgo/go/math/big/nat.go
@@ -236,7 +236,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
-var karatsubaThreshold int = 32 // computed by calibrate.go
+var karatsubaThreshold int = 40 // computed by calibrate.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
@@ -342,7 +342,7 @@ func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
-// addAt implements z += x*(1<<(_W*i)); z must be long enough.
+// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
@@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
// determine Karatsuba length k such that
//
- // x = x1*b + x0
- // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
+ // x = xh*b + x0 (0 <= x0 < b)
+ // y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
@@ -417,27 +417,44 @@ func (z nat) mul(x, y nat) nat {
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
-
- // If x1 and/or y1 are not 0, add missing terms to z explicitly:
- //
- // m+n 2*k 0
- // z = [ ... | x0*y0 ]
- // + [ x1*y1 ]
- // + [ x1*y0 ]
- // + [ x0*y1 ]
+ z = z[0 : m+n] // z has final length but may be incomplete
+ z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
+
+ // If xh != 0 or yh != 0, add the missing terms to z. For
+ //
+ // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
+ // yh = y1*b (0 <= y1 < b)
+ //
+ // the missing terms are
+ //
+ // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
+ //
+ // since all the yi for i > 1 are 0 by choice of k: If any of them
+ // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
+ // be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- x1 := x[k:] // x1 is normalized because x is
- y1 := y[k:] // y1 is normalized because y is
var t nat
- t = t.mul(x1, y1)
- copy(z[2*k:], t)
- z[2*k+len(t):].clear() // upper portion of z is garbage
- t = t.mul(x1, y0.norm())
- addAt(z, t, k)
- t = t.mul(x0.norm(), y1)
+
+ // add x0*y1*b
+ x0 := x0.norm()
+ y1 := y[k:] // y1 is normalized because y is
+ t = t.mul(x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
+
+ // add xi*y0<<i, xi*y1*b<<(i+k)
+ y0 := y0.norm()
+ for i := k; i < len(x); i += k {
+ xi := x[i:]
+ if len(xi) > k {
+ xi = xi[:k]
+ }
+ xi = xi.norm()
+ t = t.mul(xi, y0)
+ addAt(z, t, i)
+ t = t.mul(xi, y1)
+ addAt(z, t, i+k)
+ }
}
return z.norm()
@@ -493,14 +510,9 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
}
if len(v) == 1 {
- var rprime Word
- q, rprime = z.divW(u, v[0])
- if rprime > 0 {
- r = z2.make(1)
- r[0] = rprime
- } else {
- r = z2.make(0)
- }
+ var r2 Word
+ q, r2 = z.divW(u, v[0])
+ r = z2.setWord(r2)
return
}
@@ -740,7 +752,7 @@ func (x nat) string(charset string) string {
// convert power of two and non power of two bases separately
if b == b&-b {
// shift is base-b digit size in bits
- shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
+ shift := trailingZeroBits(b) // shift > 0 because b >= 2
mask := Word(1)<<shift - 1
w := x[0]
nbits := uint(_W) // number of unprocessed bits in w
@@ -993,10 +1005,9 @@ var deBruijn64Lookup = []byte{
54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
-// trailingZeroBits returns the number of consecutive zero bits on the right
-// side of the given Word.
-// See Knuth, volume 4, section 7.3.1
-func trailingZeroBits(x Word) int {
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func trailingZeroBits(x Word) uint {
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multiplying by a power of two is equivalent to
@@ -1005,18 +1016,33 @@ func trailingZeroBits(x Word) int {
// Therefore, if we have a left shifted version of this constant we can
// find by how many bits it was shifted by looking at which six bit
// substring ended up at the top of the word.
+ // (Knuth, volume 4, section 7.3.1)
switch _W {
case 32:
- return int(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
+ return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
case 64:
- return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
+ return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
default:
- panic("Unknown word size")
+ panic("unknown word size")
}
return 0
}
+// trailingZeroBits returns the number of consecutive least significant zero
+// bits of x.
+func (x nat) trailingZeroBits() uint {
+ if len(x) == 0 {
+ return 0
+ }
+ var i uint
+ for x[i] == 0 {
+ i++
+ }
+ // x[i] != 0
+ return i*_W + trailingZeroBits(x[i])
+}
+
// z = x << s
func (z nat) shl(x nat, s uint) nat {
m := len(x)
@@ -1169,29 +1195,6 @@ func (x nat) modW(d Word) (r Word) {
return divWVW(q, 0, x, d)
}
-// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
-func (x nat) powersOfTwoDecompose() (q nat, k int) {
- if len(x) == 0 {
- return x, 0
- }
-
- // One of the words must be non-zero by definition,
- // so this loop will terminate with i < len(x), and
- // i is the number of 0 words.
- i := 0
- for x[i] == 0 {
- i++
- }
- n := trailingZeroBits(x[i]) // x[i] != 0
-
- q = make(nat, len(x)-i)
- shrVU(q, x[i:], uint(n))
-
- q = q.norm()
- k = i*_W + n
- return
-}
-
// random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
@@ -1207,17 +1210,19 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
mask := Word((1 << bitLengthOfMSW) - 1)
for {
- for i := range z {
- switch _W {
- case 32:
+ switch _W {
+ case 32:
+ for i := range z {
z[i] = Word(rand.Uint32())
- case 64:
+ }
+ case 64:
+ for i := range z {
z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
}
+ default:
+ panic("unknown word size")
}
-
z[len(limit)-1] &= mask
-
if z.cmp(limit) < 0 {
break
}
@@ -1230,7 +1235,7 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
// reuses the storage of z if possible.
func (z nat) expNN(x, y, m nat) nat {
if alias(z, x) || alias(z, y) {
- // We cannot allow in place modification of x or y.
+ // We cannot allow in-place modification of x or y.
z = nil
}
@@ -1259,15 +1264,21 @@ func (z nat) expNN(x, y, m nat) nat {
// we also multiply by x, thus adding one to the power.
w := _W - int(shift)
+ // zz and r are used to avoid allocating in mul and div as
+ // otherwise the arguments would alias.
+ var zz, r nat
for j := 0; j < w; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
if m != nil {
- q, z = q.div(z, z, m)
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1277,14 +1288,17 @@ func (z nat) expNN(x, y, m nat) nat {
v = y[i]
for j := 0; j < _W; j++ {
- z = z.mul(z, z)
+ zz = zz.mul(z, z)
+ zz, z = z, zz
if v&mask != 0 {
- z = z.mul(z, x)
+ zz = zz.mul(z, x)
+ zz, z = z, zz
}
if m != nil {
- q, z = q.div(z, z, m)
+ zz, r = zz.div(r, z, m)
+ zz, r, q, z = q, z, zz, r
}
v <<= 1
@@ -1343,8 +1357,9 @@ func (n nat) probablyPrime(reps int) bool {
}
nm1 := nat(nil).sub(n, natOne)
- // 1<<k * q = nm1;
- q, k := nm1.powersOfTwoDecompose()
+ // determine q, k such that nm1 = q << k
+ k := nm1.trailingZeroBits()
+ q := nat(nil).shr(nm1, k)
nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0])))
@@ -1360,7 +1375,7 @@ NextRandom:
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
- for j := 1; j < k; j++ {
+ for j := uint(1); j < k; j++ {
y = y.mul(y, y)
quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 {