Skip to content

Commit

Permalink
Decimal: improve Sqrt performance & increase Float/SetFloat precision
Browse files Browse the repository at this point in the history
Refactor (*Decimal).Sqrt to match the big.Float implementation (still
leaving out sqrtDirect). Compared to (*big.Float).Sqrt, the target
precision is slightly lower, but the precision increase in the Newton
loop is more conservative.

Instead of using a costly Float64 call to compute the initial guess, the
guess is now converted directly from the most significant Word of the
mantissa, at a slight precision cost on 32-bit platforms (which causes
about one more loop iteration).

(*Decimal).Float and (*Decimal).SetFloat now use a higher precision when
applying the exponent (a power of five in Float, a power of two in
SetFloat).
  • Loading branch information
db47h committed May 21, 2020
1 parent 1a3f78b commit 320133a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 70 deletions.
2 changes: 0 additions & 2 deletions dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"sync"
)

const debugDecimal = true

// dec is an unsigned integer x of the form
//
// x = x[n-1]*_BD^(n-1) + x[n-2]*_BD^(n-2) + ... + x[1]*_BD + x[0]
Expand Down
75 changes: 31 additions & 44 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"math/big"
)

const debugDecimal = true // enable for debugging

// DefaultDecimalPrec is the default minimum precision used when creating a new
// Decimal from a *big.Int, uint64, int64, or string. A uint64 requires up to
// 20 digits, which amounts to 2 x 19-digit Words (64 bits) or 3 x 9-digit
Expand Down Expand Up @@ -393,67 +395,49 @@ func (x *Decimal) Float(z *big.Float) *big.Float {
if z == nil {
z = new(big.Float).SetMode(big.RoundingMode(x.mode))
}
p := uint64(z.Prec())
p := uint(z.Prec())
if p == 0 {
p = uint64(max(int(math.Ceil(float64(x.prec)*log2_10)), 64))
z.SetPrec(0).SetPrec(uint(p))
p = uint(max(int(math.Ceil(float64(x.prec)*log2_10)), 64))
}

// clear z
z.SetPrec(0)

switch x.form {
case zero:
return z.SetPrec(0).SetPrec(uint(p))
return z.SetPrec(p)
case inf:
return z.SetInf(x.neg)
return z.SetInf(x.neg).SetPrec(p)
}

// increase precision
z.SetPrec(p + 1)

// big.Float has no SetBits. Need to use a temp Int.
var i big.Int
i.SetBits(decToNat(nil, x.mant))

exp := int64(x.exp) - int64(len(x.mant)*_DW)
m := len(x.mant) * _DW
exp := int64(x.exp) - int64(m)
z = z.SetInt(&i)
// z = x·2**(m - x.exp)·5**(m - x.exp)

// normalize mantissa and apply 2 exponent
// done in two steps since SetMantExp takes an int.
z.SetMantExp(z, -m) // z = x·2**(-x.exp)·5**(m - x.exp)
z.SetMantExp(z, int(x.exp)) // z = x·5**(m - x.exp)

// now multiply/divide by 10**exp
// now multiply/divide by 5**exp
if exp != 0 {
t := new(big.Float).SetPrec(uint(p))
if exp < 0 {
if exp < MinExp {
// exponent overflow, convert mantissa first
l := uint64(len(x.mant) * _DW)
z.Quo(z, pow10Float(t, l))
exp += int64(l)
}
z.Quo(z, pow10Float(t, uint64(-exp)))
z.Quo(z, floatPow5(t, uint64(-exp)))
} else {
z.Mul(z, pow10Float(t, uint64(exp)))
}
}
return z
}

// pow10Float sets z to 10**n and returns z.
// n must not be negative.
func pow10Float(z *big.Float, n uint64) *big.Float {
const m = uint64(len(pow10tab) - 1)
if n <= m {
return z.SetUint64(pow10tab[n])
}
// n > m

z.SetUint64(pow10tab[m])
n -= m

f := new(big.Float).SetPrec(z.Prec() + _W).SetUint64(10)

for n > 0 {
if n&1 != 0 {
z.Mul(z, f)
z.Mul(z, floatPow5(t, uint64(exp)))
}
f.Mul(f, f)
n >>= 1
}

return z
// round
return z.SetPrec(p)
}

// Float32 returns the float32 value nearest to x. If x is too small to be
Expand Down Expand Up @@ -886,6 +870,8 @@ func (z *Decimal) SetFloat(x *big.Float) *Decimal {
z.SetInt(i)
exp2 -= int64(fprec)
if exp2 != 0 {
// multiply / divide by 2**exp with increased precision
z.prec += 1
t := new(Decimal).SetPrec(uint(z.prec))
if exp2 < 0 {
if exp2 < MinExp {
Expand All @@ -898,6 +884,7 @@ func (z *Decimal) SetFloat(x *big.Float) *Decimal {
} else {
z = z.Mul(z, t.pow2(uint64(exp2)))
}
z.prec -= 1
}
z.round(0)
return z
Expand Down Expand Up @@ -930,17 +917,17 @@ func (z *Decimal) SetFloat64(x float64) *Decimal {
exp2 -= 64
z.mant, z.exp = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11)
dnorm(z.mant)
// multiply / divide by 2**exp with increased precision
z.prec += 1
if exp2 != 0 {
// multiply / divide by 2**exp with increased precision
z.prec += 1
t := new(Decimal).SetPrec(uint(z.prec))
if exp2 < 0 {
z = z.Quo(z, t.pow2(uint64(-exp2)))
} else {
z = z.Mul(z, t.pow2(uint64(exp2)))
}
z.prec -= 1
}
z.prec -= 1
z.round(0)
return z
}
Expand Down
76 changes: 52 additions & 24 deletions decimal_sqrt.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package decimal

import (
"fmt"
"math"
)

// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

import "math"

var (
oneHalf = NewDecimal(0.5) // will be exact
three = NewDecimal(3.0)
oneHalf = NewDecimal(0.5) // must be exact
three = NewDecimal(3.0) // must be exact
)

// Sqrt sets z to the rounded square root of x, and returns it.
Expand Down Expand Up @@ -64,40 +67,65 @@ func (z *Decimal) Sqrt(x *Decimal) *Decimal {

// Unlike with big.Float, solving x² - z = 0 directly is faster only for
// very small precisions (<_DW/2).
//
// Solve 1/x² - z = 0 instead.
z.sqrtInverse(z)

// restore precision and re-attach halved exponent
return z.SetMantExp(z, b/2)
}

// Compute √x (to z.prec precision) by solving
// 1/t² - x = 0
// for t (using Newton's method), and then inverting.
func (z *Decimal) sqrtInverse(x *Decimal) {
if debugDecimal {
if oneHalf.acc != Exact {
panic(fmt.Sprintf("oneHalf is inexact (%v): %g", oneHalf.acc, oneHalf))
}
if three.acc != Exact {
panic(fmt.Sprintf("three is inexact (%v): %g", three.acc, three))
}
}

// Compute √x (to z.prec precision) by solving
// 1/t² - x = 0
// for t (using Newton's method), and then inverting.

// Compute initial guess for 1/√x
// xf needs only be "close enough", use a fast Decimal->Float64 conversion
xf := float64(x.mant[len(x.mant)-1]/10) / float64(pow10(uint(_DW-1-x.exp)))
t := newDecimal(z.prec).SetFloat64(1 / math.Sqrt(xf))
// t.prec = min(_DW, 17)
if _W == 32 {
t.prec = _DW
}
// t = initial guess for 1/√x

// let
// f(t) = 1/t² - x
// then
// g(t) = f(t)/f'(t) = -½t(1 - xt²)
// and the next guess is given by
// t2 = t - g(t) = ½t(3 - xt²)

u := newDecimal(prec)
v := newDecimal(prec)
xf, _ := x.Float64()
sqi := newDecimal(prec)
sqi.SetFloat64(1 / math.Sqrt(xf))
for prec := prec + _DW; sqi.prec < prec; {
sqi.prec *= 2
u.prec = sqi.prec
v.prec = sqi.prec
u.Mul(sqi, sqi) // u = sqi²
u.Mul(x, u) // = x.sqi²
v.Sub(three, u) // v = 3 - x.sqi²
u.Mul(sqi, v) // u = sqi(3 - x.sqi²)
sqi.Mul(u, oneHalf) // sqi = ½sqi(3 - x.sqi²)
u := newDecimal(z.prec)
v := newDecimal(z.prec)
for prec := z.prec + 2; t.prec < prec; {
// be less aggressive than big.Float in precision increase
// |√z - t| < 10**(-2*t.prec + 2) <= 10**-prec
t.prec = t.prec*2 - 2
u.prec = t.prec
v.prec = t.prec
u.Mul(t, t) // u = t²
u.Mul(x, u) // = x.t²
v.Sub(three, u) // v = 3 - x.t²
u.Mul(t, v) // u = t(3 - x.t²)
t.Mul(u, oneHalf) // t = ½t(3 - x.t²)
}
// sqi = 1/√x
// t = 1/√x

// x/√x = √x
z.Mul(x, sqi)

// re-attach halved exponent
return z.SetMantExp(z, b/2)
z.Mul(z, t)
}

// newDecimal returns a new *Decimal with space for twice the given
Expand Down

0 comments on commit 320133a

Please sign in to comment.