From d22d634513b2d9070875c40fec5a37648c216d43 Mon Sep 17 00:00:00 2001 From: Denis Bernard Date: Mon, 11 May 2020 23:45:11 +0200 Subject: [PATCH] dec: implement karatsuba mul and sqr + basicSqr --- dec.go | 361 ++++++++++++++++++++++++++++++++++------------ dec_arith_test.go | 16 +- dec_test.go | 12 +- decimal_test.go | 4 +- stdlib.go | 31 ++++ stdlib_test.go | 9 -- 6 files changed, 315 insertions(+), 118 deletions(-) delete mode 100644 stdlib_test.go diff --git a/dec.go b/dec.go index 78c0f92..9e36e03 100644 --- a/dec.go +++ b/dec.go @@ -138,7 +138,7 @@ func (x dec) toNat(z []Word) []Word { func (z dec) setInt(x *big.Int) dec { bb := x.Bits() // TODO(db47h): here we cannot directly copy(b, bb) - // because big.Word != decimal.Word + // because big.Word != decimal.Word. b := make([]Word, len(bb)) for i := 0; i < len(b) && i < len(bb); i++ { b[i] = Word(bb[i]) @@ -507,8 +507,6 @@ func (z dec) shr(x dec, s uint) dec { // Operands that are shorter than basicSqrThreshold are squared using // "grade school" multiplication; for operands longer than karatsubaSqrThreshold // we use the Karatsuba algorithm optimized for x == y. -var decBasicSqrThreshold = 20 // computed by calibrate_test.go -var decKaratsubaSqrThreshold = 260 // computed by calibrate_test.go // z = x*x func (z dec) sqr(x dec) dec { @@ -527,45 +525,102 @@ func (z dec) sqr(x dec) dec { z = nil // z is an alias for x - cannot reuse } - // if n < decBasicSqrThreshold { - z = z.make(2 * n) - decBasicMul(z, x, x) - return z.norm() - // } - // TODO(db47h): implement basicSqr - // if n < decKaratsubaSqrThreshold { - // z = z.make(2 * n) - // basicSqr(z, x) - // return z.norm() - // } - // TODO(db47h): implement karatsuba algorithm + if n < basicSqrThreshold { + z = z.make(2 * n) + decBasicMul(z, x, x) + return z.norm() + } + if n < karatsubaSqrThreshold { + z = z.make(2 * n) + decBasicSqr(z, x) + return z.norm() + } // Use Karatsuba multiplication optimized for x == y. 
// The algorithm and layout of z are the same as for mul. // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 - // k := karatsubaLen(n, karatsubaSqrThreshold) - - // x0 := x[0:k] - // z = z.make(max(6*k, 2*n)) - // karatsubaSqr(z, x0) // z = x0^2 - // z = z[0 : 2*n] - // z[2*k:].clear() - - // if k < n { - // tp := getNat(2 * k) - // t := *tp - // x0 := x0.norm() - // x1 := x[k:] - // t = t.mul(x0, x1) - // addAt(z, t, k) - // addAt(z, t, k) // z = 2*x1*x0*b + x0^2 - // t = t.sqr(x1) - // addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 - // putNat(tp) - // } + k := karatsubaLen(n, karatsubaSqrThreshold) + + x0 := x[0:k] + z = z.make(max(6*k, 2*n)) + decKaratsubaSqr(z, x0) // z = x0^2 + z = z[0 : 2*n] + z[2*k:].clear() + + if k < n { + tp := getDec(2 * k) + t := *tp + x0 := x0.norm() + x1 := x[k:] + t = t.mul(x0, x1) + decAddAt(z, t, k) + decAddAt(z, t, k) // z = 2*x1*x0*b + x0^2 + t = t.sqr(x1) + decAddAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 + putDec(tp) + } - // return z.norm() + return z.norm() +} + +// decBasicSqr sets z = x*x and is asymptotically faster than decBasicMul +// by about a factor of 2, but slower for small arguments due to overhead. +// Requirements: len(x) > 0, len(z) == 2*len(x) +// The (non-normalized) result is placed in z. +func decBasicSqr(z, x dec) { + n := len(x) + tp := getDec(2 * n) + t := *tp // temporary variable to hold the products + t.clear() + z[1], z[0] = mul10WW_g(x[0], x[0]) // the initial square + for i := 1; i < n; i++ { + d := x[i] + // z collects the squares x[i] * x[i] + z[2*i+1], z[2*i] = mul10WW_g(d, d) + // t collects the products x[i] * x[j] where j < i + t[2*i] = addMul10VVW(t[i:2*i], x[0:i], d) + } + // t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products + t[2*n-1] = mulAdd10VWW(t[1:2*n-1], t[1:2*n-1], 2, 0) + add10VV(z, z, t) // combine the result + putDec(tp) +} + +// decKaratsubaSqr squares x and leaves the result in z. +// len(x) must be a power of 2 and len(z) >= 6*len(x). 
+// The (non-normalized) result is placed in z[0 : 2*len(x)]. +// +// The algorithm and the layout of z are the same as for karatsuba. +func decKaratsubaSqr(z, x dec) { + n := len(x) + + if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { + decBasicSqr(z[:2*n], x) + return + } + + n2 := n >> 1 + x1, x0 := x[n2:], x[0:n2] + + decKaratsubaSqr(z, x0) + decKaratsubaSqr(z[n:], x1) + + // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 + xd := z[2*n : 2*n+n2] + if sub10VV(xd, x1, x0) != 0 { + sub10VV(xd, x0, x1) + } + + p := z[n*3:] + decKaratsubaSqr(p, xd) + + r := z[n*4:] + copy(r, z[:n*2]) + + decKaratsubaAdd(z[n2:], r, n) + decKaratsubaAdd(z[n2:], r[n:], n) + decKaratsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 } // decBasicMul multiplies x and y and leaves the result in z. @@ -599,71 +654,71 @@ func (z dec) mul(x, y dec) dec { } // use basic multiplication if the numbers are small - // if n < karatsubaThreshold { - z = z.make(m + n) - decBasicMul(z, x, y) - return z.norm() - // } + if n < karatsubaThreshold { + z = z.make(m + n) + decBasicMul(z, x, y) + return z.norm() + } // m >= n && n >= karatsubaThreshold && n >= 2 - // // determine Karatsuba length k such that - // // - // // x = xh*b + x0 (0 <= x0 < b) - // // y = yh*b + y0 (0 <= y0 < b) - // // b = 1<<(_W*k) ("base" of digits xi, yi) - // // - // k := karatsubaLen(n, karatsubaThreshold) - // // k <= n + // determine Karatsuba length k such that + // + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) + // b = 10**(_DW*k) ("base" of digits xi, yi) + // + k := karatsubaLen(n, karatsubaThreshold) + // k <= n // // multiply x0 and y0 via Karatsuba - // x0 := x[0:k] // x0 is not normalized - // y0 := y[0:k] // y0 is not normalized - // z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y - // karatsuba(z, x0, y0) - // z = z[0 : m+n] // z has final length but may be incomplete - // z[2*k:].clear() // upper portion of z is garbage (and 2*k <= 
m+n since k <= n <= m) - - // // If xh != 0 or yh != 0, add the missing terms to z. For - // // - // // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) - // // yh = y1*b (0 <= y1 < b) - // // - // // the missing terms are - // // - // // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 - // // - // // since all the yi for i > 1 are 0 by choice of k: If any of them - // // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would - // // be a larger valid threshold contradicting the assumption about k. - // // - // if k < n || m != n { - // tp := getNat(3 * k) - // t := *tp - - // // add x0*y1*b - // x0 := x0.norm() - // y1 := y[k:] // y1 is normalized because y is - // t = t.mul(x0, y1) // update t so we don't lose t's underlying array - // addAt(z, t, k) - - // // add xi*y0<<i, xi*y1*b<<(i+k) - // y0 := y0.norm() - // for i := k; i < len(x); i += k { - // xi := x[i:] - // if len(xi) > k { - // xi = xi[:k] - // } - // xi = xi.norm() - // t = t.mul(xi, y0) - // addAt(z, t, i) - // t = t.mul(xi, y1) - // addAt(z, t, i+k) - // } + x0 := x[0:k] // x0 is not normalized + y0 := y[0:k] // y0 is not normalized + z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y + decKaratsuba(z, x0, y0) + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 + // + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. 
+ // + if k < n || m != n { + tp := getDec(3 * k) + t := *tp + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array + decAddAt(z, t, k) + + // add xi*y0<<i, xi*y1*b<<(i+k) + y0 := y0.norm() + for i := k; i < len(x); i += k { + xi := x[i:] + if len(xi) > k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + decAddAt(z, t, i) + t = t.mul(xi, y1) + decAddAt(z, t, i+k) + } - // putNat(tp) - // } + putDec(tp) + } - // return z.norm() + return z.norm() } // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; @@ -786,6 +841,8 @@ func (z dec) expNN(x, y, m dec) dec { // - len(z) >= len(u)-len(v) // // See Burnikel, Ziegler, "Fast Recursive Division", Algorithm 1 and 2. +// TODO(db47h): review https://pure.mpg.de/rest/items/item_1819444_4/component/file_2599480/content +// and make sure that when calling divBasic, the preconditions are met. func (z dec) divRecursive(u, v dec) { // Recursion depth is less than 2 log2(len(v)) // Allocate a slice of temporaries to be reused across recursion. @@ -944,3 +1001,117 @@ func decAddAt(z, x dec, i int) { } } } + +// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. +// Factored out for readability - do not use outside karatsuba. +func decKaratsubaAdd(z, x dec, n int) { + if c := add10VV(z[0:n], z, x); c != 0 { + add10VW(z[n:n+n>>1], z[n:], c) + } +} + +// Like decKaratsubaAdd, but does subtract. +func decKaratsubaSub(z, x dec, n int) { + if c := sub10VV(z[0:n], z, x); c != 0 { + sub10VW(z[n:n+n>>1], z[n:], c) + } +} + +// decKaratsuba multiplies x and y and leaves the result in z. +// Both x and y must have the same length n and n must be a +// power of 2. The result vector z must have len(z) >= 6*n. +// The (non-normalized) result is placed in z[0 : 2*n]. +func decKaratsuba(z, x, y dec) { + n := len(y) + + // Switch to basic multiplication if numbers are odd or small. 
+ // (n is always even if karatsubaThreshold is even, but be + // conservative) + if n&1 != 0 || n < karatsubaThreshold || n < 2 { + decBasicMul(z, x, y) + return + } + // n&1 == 0 && n >= karatsubaThreshold && n >= 2 + + // Karatsuba multiplication is based on the observation that + // for two numbers x and y with: + // + // x = x1*b + x0 + // y = y1*b + y0 + // + // the product x*y can be obtained with 3 products z2, z1, z0 + // instead of 4: + // + // x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0 + // = z2*b*b + z1*b + z0 + // + // with: + // + // xd = x1 - x0 + // yd = y0 - y1 + // + // z1 = xd*yd + z2 + z0 + // = (x1-x0)*(y0 - y1) + z2 + z0 + // = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0 + // = x1*y0 - z2 - z0 + x0*y1 + z2 + z0 + // = x1*y0 + x0*y1 + + // split x, y into "digits" + n2 := n >> 1 // n2 >= 1 + x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0 + y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 + + // z is used for the result and temporary storage: + // + // 6*n 5*n 4*n 3*n 2*n 1*n 0*n + // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] + // + // For each recursive call of karatsuba, an unused slice of + // z is passed in that has (at least) half the length of the + // caller's z. 
+ + // compute z0 and z2 with the result "in place" in z + decKaratsuba(z, x0, y0) // z0 = x0*y0 + decKaratsuba(z[n:], x1, y1) // z2 = x1*y1 + + // compute xd (or the negative value if underflow occurs) + s := 1 // sign of product xd*yd + xd := z[2*n : 2*n+n2] + if sub10VV(xd, x1, x0) != 0 { // x1-x0 + s = -s + sub10VV(xd, x0, x1) // x0-x1 + } + + // compute yd (or the negative value if underflow occurs) + yd := z[2*n+n2 : 3*n] + if sub10VV(yd, y0, y1) != 0 { // y0-y1 + s = -s + sub10VV(yd, y1, y0) // y1-y0 + } + + // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 + // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 + p := z[n*3:] + decKaratsuba(p, xd, yd) + + // save original z2:z0 + // (ok to use upper half of z since we're done recursing) + r := z[n*4:] + copy(r, z[:n*2]) + + // add up all partial products + // + // 2*n n 0 + // z = [ z2 | z0 ] + // + [ z0 ] + // + [ z2 ] + // + [ p ] + // + decKaratsubaAdd(z[n2:], r, n) + decKaratsubaAdd(z[n2:], r[n:], n) + if s > 0 { + decKaratsubaAdd(z[n2:], p, n) + } else { + decKaratsubaSub(z[n2:], p, n) + } +} diff --git a/dec_arith_test.go b/dec_arith_test.go index 2123be2..942d5eb 100644 --- a/dec_arith_test.go +++ b/dec_arith_test.go @@ -37,7 +37,7 @@ func rnd10V(n int) []Word { return v } -func TestDiv10W(t *testing.T) { +func TestDecDiv10W(t *testing.T) { h, l := rnd10W(), Word(rnd.Uint64()) for i := 0; i < 1e7; i++ { q, r := div10W(h, l) @@ -50,7 +50,7 @@ func TestDiv10W(t *testing.T) { var benchH, benchL Word -func BenchmarkDiv10W_bits(b *testing.B) { +func BenchmarkDecDiv10W_bits(b *testing.B) { h, l := rnd10W(), Word(rnd.Uint64()) for i := 0; i < b.N; i++ { h, l := bits.Div(uint(h), uint(l), _DB) @@ -58,7 +58,7 @@ func BenchmarkDiv10W_bits(b *testing.B) { } } -func BenchmarkDiv10W_mul(b *testing.B) { +func BenchmarkDecDiv10W_mul(b *testing.B) { h, l := rnd10W(), Word(rnd.Uint64()) for i := 0; i < b.N; i++ { benchH, benchL = div10W(h, l) @@ -99,7 +99,7 @@ func testFun10VV(t 
*testing.T, msg string, f fun10VV, a arg10VV) { } } -func TestFun10VV(t *testing.T) { +func TestDecFun10VV(t *testing.T) { for _, a := range sum10VV { arg := a testFun10VV(t, "add10VV_g", add10VV_g, arg) @@ -119,7 +119,7 @@ func TestFun10VV(t *testing.T) { } } -func BenchmarkAdd10VV(b *testing.B) { +func BenchmarkDecAdd10VV(b *testing.B) { for _, n := range benchSizes { if isRaceBuilder && n > 1e3 { continue @@ -136,6 +136,8 @@ func BenchmarkAdd10VV(b *testing.B) { } } +// TODO(db47h): complete port of the tests + // func BenchmarkSubVV(b *testing.B) { // for _, n := range benchSizes { // if isRaceBuilder && n > 1e3 { @@ -387,7 +389,7 @@ func BenchmarkAdd10VV(b *testing.B) { // } // } -// // TODO(gri) mulAddVWW and divWVW are symmetric operations but +// // TODO(db47h) mulAddVWW and divWVW are symmetric operations but // // their signature is not symmetric. Try to unify. // type funWVW func(z []Word, xn Word, x []Word, y Word) (r Word) @@ -448,7 +450,7 @@ func BenchmarkAdd10VV(b *testing.B) { // x, y, c Word // q, r Word // }{ -// // TODO(agl): These will only work on 64-bit platforms. +// // TODO(db47h): These will only work on 64-bit platforms. 
// // {15064310297182388543, 0xe7df04d2d35d5d80, 13537600649892366549, 13644450054494335067, 10832252001440893781}, // // {15064310297182388543, 0xdab2f18048baa68d, 13644450054494335067, 12869334219691522700, 14233854684711418382}, // {_M, _M, 0, _M - 1, 1}, diff --git a/dec_test.go b/dec_test.go index eba33b6..a22f622 100644 --- a/dec_test.go +++ b/dec_test.go @@ -621,7 +621,7 @@ func TestDecFibo(t *testing.T) { } } -func BenchmarkFibo(b *testing.B) { +func BenchmarkDecFibo(b *testing.B) { for i := 0; i < b.N; i++ { decFibo(1e0) decFibo(1e1) @@ -830,7 +830,7 @@ func TestGoIssue37499(t *testing.T) { } // TODO(bd47h): move this to decimal_test -func benchmarkDiv(b *testing.B, aSize, bSize int) { +func benchmarkDecimalDiv(b *testing.B, aSize, bSize int) { aa := rndDec1(aSize) bb := rndDec1(bSize) if aa.cmp(bb) < 0 { @@ -845,15 +845,17 @@ func benchmarkDiv(b *testing.B, aSize, bSize int) { } } -func BenchmarkDiv(b *testing.B) { +func BenchmarkDecimalDiv(b *testing.B) { sizes := []int{ 10, 20, 50, 100, 200, 500, 1000, - 1e4, 1e5, 1e6, 1e7, + 1e4, + // TODO(db47h): enable these after optimizing + // 1e5, 1e6, 1e7, } for _, i := range sizes { j := 2 * i b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) { - benchmarkDiv(b, j, i) + benchmarkDecimalDiv(b, j, i) }) } } diff --git a/decimal_test.go b/decimal_test.go index bfd16ef..2d6b570 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -29,7 +29,7 @@ var intData = []struct { {"0", 0, 0, nil, _DW, 0}, } -func TestDnorm(t *testing.T) { +func TestDecimal_dnorm(t *testing.T) { for i := 0; i < 10000; i++ { again: w := uint(rand.Uint64()) % _DB @@ -95,7 +95,7 @@ func TestDecimal_SetString(t *testing.T) { } } -func BenchmarkDnorm(b *testing.B) { +func BenchmarkDecimal_dnorm(b *testing.B) { d := dec(nil).make(1000) for i := range d { d[i] = Word(rand.Uint64()) % _DB diff --git a/stdlib.go b/stdlib.go index 724096d..b45bd2d 100644 --- a/stdlib.go +++ b/stdlib.go @@ -240,3 +240,34 @@ func (err ErrNaN) Error() string { } type 
nat []Word + +// Operands that are shorter than karatsubaThreshold are multiplied using +// "grade school" multiplication; for longer operands the Karatsuba algorithm +// is used. +const karatsubaThreshold = 40 // computed by calibrate_test.go + +// karatsubaLen computes an approximation to the maximum k <= n such that +// k = p<<i for a number p <= threshold and an i >= 0. Thus, the +// result is the largest number that can be divided repeatedly by 2 before +// becoming about the value of threshold. +func karatsubaLen(n, threshold int) int { + i := uint(0) + for n > threshold { + n >>= 1 + i++ + } + return n << i +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +// Operands that are shorter than basicSqrThreshold are squared using +// "grade school" multiplication; for operands longer than karatsubaSqrThreshold +// we use the Karatsuba algorithm optimized for x == y. +var basicSqrThreshold = 20 // computed by calibrate_test.go +var karatsubaSqrThreshold = 260 // computed by calibrate_test.go diff --git a/stdlib_test.go b/stdlib_test.go deleted file mode 100644 index 5f26805..0000000 --- a/stdlib_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package decimal - -import "testing" - -func TestMaxBase(t *testing.T) { - if MaxBase != len(digits) { - t.Fatalf("%d != %d", MaxBase, len(digits)) - } -}