Skip to content

Commit

Permalink
Implement division by constant _DB using multiplication
Browse files Browse the repository at this point in the history
Doubles the performance over bits.Div(n1, n0, _DB) on x86 and x86_64.
While it theoretically benefits all base-_DB arithmetic primitives,
it prevents them from being inlined, leading to diminishing returns
until we start manipulating numbers > 100 Words.
  • Loading branch information
db47h committed May 10, 2020
1 parent 8308a49 commit 5cfed50
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 98 deletions.
221 changes: 189 additions & 32 deletions dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,6 @@ import (

const debugDecimal = true

const (
// _W * log10(2) = decimal digits per word. 9 decimal digits per 32 bits
// word and 19 per 64 bits word.
// 30103/100000 approximates log10(2); the integer division truncates
// 32*30103/100000 to 9 and 64*30103/100000 to 19.
_WD = _W * 30103 / 100000
// Decimal base for a word. 1e9 for 32 bits words and 1e19 for 64 bits
// words.
// We want this value to be a const. This is a dirty hack to avoid
// conditional compilation; it will break if bits.UintSize != 32 or 64
//
// How the expression selects the base using integer division only:
//   _WD == 9:  _WD/19 == 0, _WD/9 == 1  =>  0 + 1000000000 = 1e9
//   _WD == 19: _WD/19 == 1, _WD/9 == 2  =>  9999999998000000000 + 2000000000 = 1e19
_BD = 9999999998000000000*(_WD/19) + 1000000000*(_WD/9)
// Largest value a single decimal word can hold (_BD - 1).
_MD = _BD - 1
)

// dec is an unsigned integer x of the form
//
// x = x[n-1]*_BD^(n-1) + x[n-2]*_BD^(n-2) + ... + x[1]*_BD + x[0]
Expand Down Expand Up @@ -56,22 +44,22 @@ func (z dec) norm() dec {
// digits returns the number of decimal digits of x.
//
// Every word below the most significant one contributes _DW digits; the
// most significant word x[len(x)-1] contributes decDigits of its value.
// An empty x has 0 digits.
func (x dec) digits() uint {
	if i := len(x) - 1; i >= 0 {
		return uint(i*_DW) + decDigits(uint(x[i]))
	}
	return 0
}

// ntz returns the number of trailing decimal zeros of x.
//
// Each all-zero low-order word accounts for _DW zeros; the first non-zero
// word adds decTrailingZeros of its value. If x is empty or every word is
// zero, ntz returns 0.
func (x dec) ntz() uint {
	for i, w := range x {
		if w != 0 {
			return uint(i)*_DW + decTrailingZeros(uint(w))
		}
	}
	return 0
}

func (x dec) digit(i uint) uint {
j, i := bits.Div(0, i, _WD)
j, i := bits.Div(0, i, _DW)
if j >= uint(len(x)) {
return 0
}
Expand Down Expand Up @@ -110,13 +98,13 @@ func (z dec) setWord(x Word) dec {

func (z dec) setUint64(x uint64) (dec, int32) {
dig := int32(decDigits64(x))
if w := Word(x); uint64(w) == x && w < _BD {
if w := Word(x); uint64(w) == x && w < _DB {
return z.setWord(w), dig
}
// x could be a 2 to 3 words value
z = z.make(int(dig+_WD-1) / _WD)
z = z.make(int(dig+_DW-1) / _DW)
for i := 0; i < len(z); i++ {
hi, lo := bits.Div64(0, x, _BD)
hi, lo := bits.Div64(0, x, _DB)
z[i] = Word(lo)
x = hi
}
Expand All @@ -138,7 +126,7 @@ func (x dec) toNat(z []Word) []Word {
// r = zz & _B; zz = zz >> _W
var r Word
for j := len(zz) - 1; j >= 0; j-- {
zz[j], r = mulAddWWW(r, _BD, zz[j])
zz[j], r = mulAddWWW(r, _DB, zz[j])
}
zz = zz.norm()
z[i] = r
Expand All @@ -156,7 +144,7 @@ func (z dec) setInt(x *big.Int) dec {
b[i] = Word(bb[i])
}
for i := 0; i < len(z); i++ {
z[i] = divWVW(b, 0, b, _BD)
z[i] = divWVW(b, 0, b, _DB)
}
z = z.norm()
return z
Expand All @@ -165,7 +153,7 @@ func (z dec) setInt(x *big.Int) dec {
// sticky returns 1 if there's a non zero digit within the
// i least significant digits, otherwise it returns 0.
func (x dec) sticky(i uint) uint {
j, i := bits.Div(0, i, _WD)
j, i := bits.Div(0, i, _DW)
if j >= uint(len(x)) {
if len(x) == 0 {
return 0
Expand Down Expand Up @@ -327,6 +315,8 @@ func putDec(x *dec) {

var decPool sync.Pool

// divRecursiveThreshold is the divisor length, in words, below which
// division falls back to divBasic instead of recursive division.
const divRecursiveThreshold = 100

// q = (uIn-r)/vIn, with 0 <= r < vIn
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
Expand All @@ -339,7 +329,7 @@ func (z dec) divLarge(u, uIn, vIn dec) (q, r dec) {
m := len(uIn)

// D1.
d := _BD / (vIn[n-1] + 1)
d := _DB / (vIn[n-1] + 1)
// do not modify vIn, it may be used by another goroutine simultaneously
vp := getDec(n)
v := *vp
Expand All @@ -356,11 +346,11 @@ func (z dec) divLarge(u, uIn, vIn dec) (q, r dec) {
q = z.make(m - n + 1)

// TODO(db47h): implement divRecursive
// if n < divRecursiveThreshold {
q.divBasic(u, v)
// } else {
// q.divRecursive(u, v)
// }
if n < divRecursiveThreshold {
q.divBasic(u, v)
} else {
q.divRecursive(u, v)
}
putDec(vp)

q = q.norm()
Expand All @@ -385,7 +375,7 @@ func (q dec) divBasic(u, v dec) {
vn1 := v[n-1]
for j := m; j >= 0; j-- {
// D3.
qhat := Word(_MD)
qhat := Word(_DMax)
var ujn Word
if j+n < len(u) {
ujn = u[j+n]
Expand Down Expand Up @@ -475,9 +465,9 @@ func (z dec) shl(x dec, s uint) dec {
}
// m > 0

n := m + int(s/_WD)
n := m + int(s/_DW)
z = z.make(n + 1)
z[n] = shl10VU(z[n-m:n], x, s%_WD)
z[n] = shl10VU(z[n-m:n], x, s%_DW)
z[0 : n-m].clear()

return z.norm()
Expand All @@ -495,14 +485,14 @@ func (z dec) shr(x dec, s uint) dec {
}

m := len(x)
n := m - int(s/_WD)
n := m - int(s/_DW)
if n <= 0 {
return z[:0]
}
// n > 0

z = z.make(n)
shr10VU(z, x[m-n:], s%_WD)
shr10VU(z, x[m-n:], s%_DW)

return z.norm()
}
Expand Down Expand Up @@ -780,3 +770,170 @@ func (z dec) expNN(x, y, m dec) dec {

return z.norm()
}

// divRecursive divides u by v using recursive (divide and conquer)
// division. The quotient is stored in the pre-allocated z and the
// remainder is left in u.
//
// Precondition:
// - len(z) >= len(u)-len(v)
//
// See Burnikel, Ziegler, "Fast Recursive Division", Algorithm 1 and 2.
func (z dec) divRecursive(u, v dec) {
	// Scratch space shared by every recursion level; sized so that it can
	// hold a Karatsuba product of operands as large as v.
	scratch := getDec(3 * len(v))

	// The recursion depth is below 2*log2(len(v)). Each level lazily
	// allocates one temporary which is then reused by later calls at the
	// same depth.
	maxDepth := 2 * bits.Len(uint(len(v)))
	perLevel := make([]*dec, maxDepth)

	z.clear()
	z.divRecursiveStep(u, v, 0, scratch, perLevel)

	// Hand back whatever temporaries were actually allocated.
	for i := range perLevel {
		if t := perLevel[i]; t != nil {
			putDec(t)
		}
	}
	putDec(scratch)
}

// divRecursiveStep computes the division of u by v.
// - z must be large enough to hold the quotient
// - the quotient is accumulated into z with decAddAt, so z must be
//   cleared by the caller (divRecursive and the recursive calls below
//   do so)
// - the remainder will overwrite u
//
// depth selects the entry of temps used to hold this level's quotient
// estimate; tmp is scratch storage for the product qhat*v_l and is shared
// by all levels.
func (z dec) divRecursiveStep(u, v dec, depth int, tmp *dec, temps []*dec) {
	u = u.norm()
	v = v.norm()

	if len(u) == 0 {
		z.clear()
		return
	}
	n := len(v)
	if n < divRecursiveThreshold {
		// Small divisor: schoolbook division is used directly.
		z.divBasic(u, v)
		return
	}
	m := len(u) - n
	if m < 0 {
		// u is shorter than v: nothing is added to z and u is left
		// untouched as the remainder.
		return
	}

	// Produce the quotient by blocks of B words.
	// Division by v (length n) is done using a length n/2 division
	// and a length n/2 multiplication for each block. The final
	// complexity is driven by multiplication complexity.
	B := n / 2

	// Allocate a dec for qhat below.
	if temps[depth] == nil {
		temps[depth] = getDec(n)
	} else {
		*temps[depth] = temps[depth].make(B + 1)
	}

	j := m
	for j > B {
		// Divide u[j-B:j+n] by v. Keep remainder in u
		// for next block.
		//
		// The following property will be used (Lemma 2):
		// if u = u1 << s + u0
		//    v = v1 << s + v0
		// then floor(u1/v1) >= floor(u/v)
		//
		// Moreover, the difference is at most 2 if len(v1) >= len(u/v)
		// We choose s = B-1 since len(v)-B >= B+1 >= len(u/v)
		s := (B - 1)
		// Except for the first step, the top digits are always
		// a division remainder, so the quotient length is <= n.
		uu := u[j-B:]

		qhat := *temps[depth]
		qhat.clear()
		qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
		qhat = qhat.norm()
		// Adjust the quotient:
		//    u = u_h << s + u_l
		//    v = v_h << s + v_l
		//  u_h = q̂ v_h + rh
		//    u = q̂ (v - v_l) + rh << s + u_l
		// After the above step, u contains a remainder:
		//    u = rh << s + u_l
		// and we need to subtract q̂ v_l
		//
		// But it may be a bit too large, in which case q̂ needs to be smaller.
		qhatv := tmp.make(3 * n)
		qhatv.clear()
		qhatv = qhatv.mul(qhat, v[:s])
		for i := 0; i < 2; i++ {
			e := qhatv.cmp(uu.norm())
			if e <= 0 {
				break
			}
			// q̂ is too large: decrement it, remove one v_l from the
			// product and give one v_h back to the partial remainder.
			sub10VW(qhat, qhat, 1)
			c := sub10VV(qhatv[:s], qhatv[:s], v[:s])
			if len(qhatv) > s {
				sub10VW(qhatv[s:], qhatv[s:], c)
			}
			decAddAt(uu[s:], v[s:], 0)
		}
		if qhatv.cmp(uu.norm()) > 0 {
			panic("impossible")
		}
		c := sub10VV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv)
		if c > 0 {
			sub10VW(uu[len(qhatv):], uu[len(qhatv):], c)
		}
		decAddAt(z, qhat, j-B)
		j -= B
	}

	// Now u < (v<<B), compute lower words in the same way.
	// Choose shift = B-1 again, exactly as in the loop above, so that the
	// bound from Lemma 2 (estimate off by at most 2) also holds for this
	// last block. Using s = B here was an off-by-one.
	s := B - 1
	qhat := *temps[depth]
	qhat.clear()
	qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
	qhat = qhat.norm()
	qhatv := tmp.make(3 * n)
	qhatv.clear()
	qhatv = qhatv.mul(qhat, v[:s])
	// Set the correct remainder as before.
	for i := 0; i < 2; i++ {
		if e := qhatv.cmp(u.norm()); e > 0 {
			sub10VW(qhat, qhat, 1)
			c := sub10VV(qhatv[:s], qhatv[:s], v[:s])
			if len(qhatv) > s {
				sub10VW(qhatv[s:], qhatv[s:], c)
			}
			decAddAt(u[s:], v[s:], 0)
		}
	}
	if qhatv.cmp(u.norm()) > 0 {
		panic("impossible")
	}
	c := sub10VV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv)
	if c > 0 {
		c = sub10VW(u[len(qhatv):], u[len(qhatv):], c)
	}
	if c > 0 {
		// A borrow out of the top word would mean a negative remainder.
		panic("impossible")
	}

	// Done!
	decAddAt(z, qhat.norm(), 0)
}

// decAddAt implements z += x*10**(_DW*i), i.e. it adds x to z starting at
// word index i; z must be long enough.
// (dec.add is not used here because z has to remain the same slice, and
// normalizing z after each addition is unnecessary.)
func decAddAt(z, x dec, i int) {
	n := len(x)
	if n == 0 {
		return
	}
	carry := add10VV(z[i:i+n], z[i:], x)
	if carry == 0 {
		return
	}
	// Propagate the carry into the words above the added span, if any.
	if end := i + n; end < len(z) {
		add10VW(z[end:], z[end:], carry)
	}
}
Loading

0 comments on commit 5cfed50

Please sign in to comment.