Skip to content

Commit

Permalink
Decimal/dec: implement basic multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
db47h committed May 6, 2020
1 parent 7e163cc commit 9cf3940
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 5 deletions.
177 changes: 177 additions & 0 deletions dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ func (z dec) setWord(x Word) dec {
return z
}

func (z dec) setUint64(x uint64) (dec, int32) {
dig := int32(decDigits64(x))
if w := Word(x); uint64(w) == x {
return z.setWord(w), dig
}
// x could be a 2 to 3 words value
z = z.make(int(dig+_WD-1) / _WD)
for i := 0; i < len(z) && x != 0; i++ {
hi, lo := bits.Div64(0, x, _BD)
z[i] = Word(lo)
x = hi
}
return z, dig
}

// setInt sets z = x.mant
func (z dec) setInt(x *big.Int) dec {
bb := x.Bits()
Expand Down Expand Up @@ -272,3 +287,165 @@ func putDec(x *dec) {
}

var decPool sync.Pool

// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// we use the Karatsuba algorithm optimized for x == y.
var decBasicSqrThreshold = 20 // computed by calibrate_test.go
var decKaratsubaSqrThreshold = 260 // computed by calibrate_test.go

// z = x*x
func (z dec) sqr(x dec) dec {
n := len(x)
switch {
case n == 0:
return z[:0]
case n == 1:
d := x[0]
z = z.make(2)
z[1], z[0] = mul10WW(d, d)
return z.norm()
}

if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}

// if n < decBasicSqrThreshold {
z = z.make(2 * n)
decBasicMul(z, x, x)
return z.norm()
// }
// TODO(db47h): implement basicSqr
// if n < decKaratsubaSqrThreshold {
// z = z.make(2 * n)
// basicSqr(z, x)
// return z.norm()
// }
// TODO(db47h): implement aratsuba algorithm
// Use Karatsuba multiplication optimized for x == y.
// The algorithm and layout of z are the same as for mul.

// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2

// k := karatsubaLen(n, karatsubaSqrThreshold)

// x0 := x[0:k]
// z = z.make(max(6*k, 2*n))
// karatsubaSqr(z, x0) // z = x0^2
// z = z[0 : 2*n]
// z[2*k:].clear()

// if k < n {
// tp := getNat(2 * k)
// t := *tp
// x0 := x0.norm()
// x1 := x[k:]
// t = t.mul(x0, x1)
// addAt(z, t, k)
// addAt(z, t, k) // z = 2*x1*x0*b + x0^2
// t = t.sqr(x1)
// addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
// putNat(tp)
// }

// return z.norm()
}

// decBasicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func decBasicMul(z, x, y dec) {
z[0 : len(x)+len(y)].clear() // initialize z
for i, d := range y {
if d != 0 {
z[len(x)+i] = addMul10VVW(z[i:i+len(x)], x, d)
}
}
}

func (z dec) mul(x, y dec) dec {
m := len(x)
n := len(y)

switch {
case m < n:
return z.mul(y, x)
case m == 0 || n == 0:
return z[:0]
case n == 1:
return z.mulAddWW(x, y[0], 0)
}
// m >= n > 1

// determine if z can be reused
if alias(z, x) || alias(z, y) {
z = nil // z is an alias for x or y - cannot reuse
}

// use basic multiplication if the numbers are small
// if n < karatsubaThreshold {
z = z.make(m + n)
decBasicMul(z, x, y)
return z.norm()
// }
// m >= n && n >= karatsubaThreshold && n >= 2

// // determine Karatsuba length k such that
// //
// // x = xh*b + x0 (0 <= x0 < b)
// // y = yh*b + y0 (0 <= y0 < b)
// // b = 1<<(_W*k) ("base" of digits xi, yi)
// //
// k := karatsubaLen(n, karatsubaThreshold)
// // k <= n

// // multiply x0 and y0 via Karatsuba
// x0 := x[0:k] // x0 is not normalized
// y0 := y[0:k] // y0 is not normalized
// z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
// karatsuba(z, x0, y0)
// z = z[0 : m+n] // z has final length but may be incomplete
// z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

// // If xh != 0 or yh != 0, add the missing terms to z. For
// //
// // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
// // yh = y1*b (0 <= y1 < b)
// //
// // the missing terms are
// //
// // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
// //
// // since all the yi for i > 1 are 0 by choice of k: If any of them
// // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
// // be a larger valid threshold contradicting the assumption about k.
// //
// if k < n || m != n {
// tp := getNat(3 * k)
// t := *tp

// // add x0*y1*b
// x0 := x0.norm()
// y1 := y[k:] // y1 is normalized because y is
// t = t.mul(x0, y1) // update t so we don't lose t's underlying array
// addAt(z, t, k)

// // add xi*y0<<i, xi*y1*b<<(i+k)
// y0 := y0.norm()
// for i := k; i < len(x); i += k {
// xi := x[i:]
// if len(xi) > k {
// xi = xi[:k]
// }
// xi = xi.norm()
// t = t.mul(xi, y0)
// addAt(z, t, i)
// t = t.mul(xi, y1)
// addAt(z, t, i+k)
// }

// putNat(tp)
// }

// return z.norm()
}
40 changes: 40 additions & 0 deletions dec_arith.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ var maxDigits = [...]uint{
// In other words, n the number of digits required to represent n.
// Returns 0 for x == 0.
func decDigits(x uint) (n uint) {
if bits.UintSize == 32 {
return decDigits32(x)
}
return decDigits64(uint64(x))
}

func decDigits64(x uint64) (n uint) {
n = maxDigits[bits.Len64(x)]
if x < uint64(pow10(n-1)) {
n--
}
return n
}

func decDigits32(x uint) (n uint) {
n = maxDigits[bits.Len(x)]
if x < pow10(n-1) {
n--
Expand Down Expand Up @@ -188,3 +203,28 @@ func mulAdd10WWW(x, y, c Word) (z1, z0 Word) {
hi, lo = bits.Div(hi+cc, lo, _BD)
return Word(hi), Word(lo)
}

// z1<<_W + z0 = x*y
func mul10WW(x, y Word) (z1, z0 Word) {
hi, lo := bits.Mul(uint(x), uint(y))
hi, lo = bits.Div(hi, lo, _BD)
return Word(hi), Word(lo)
}

func add10WW(x, y Word) (s, c Word) {
r, cc := bits.Add(uint(x), uint(y), 0)
if r >= _BD {
r -= _BD
cc = 1
}
return Word(r), Word(cc)
}

func addMul10VVW(z, x []Word, y Word) (c Word) {
for i := 0; i < len(z) && i < len(x); i++ {
z1, z0 := mulAdd10WWW(x[i], y, z[i])
z[i], c = add10WW(z0, c)
c += z1
}
return
}
Loading

0 comments on commit 9cf3940

Please sign in to comment.