From 8b88c09fa5b5c7d3278f30d7dc75f9299bfeb0ec Mon Sep 17 00:00:00 2001 From: Denis Bernard Date: Sun, 3 May 2020 18:29:43 +0200 Subject: [PATCH] Implement left-shifted mantissa --- arith_dec.go | 85 ++++++++++++++++--- dec.go | 70 +++++++++------- dec_conv.go | 7 ++ dec_test.go | 148 +++++++++++++++++++++++++++++----- decimal.go | 118 ++++++++++++++++++--------- decconv.go => decimal_conv.go | 119 ++++++++++++++++++++++++++- decimal_test.go | 54 ++++++++++++- stdlib.go | 112 ++++++++++++++++++++++++- 8 files changed, 609 insertions(+), 104 deletions(-) create mode 100644 dec_conv.go rename decconv.go => decimal_conv.go (64%) diff --git a/arith_dec.go b/arith_dec.go index 44d91ec..406d92a 100644 --- a/arith_dec.go +++ b/arith_dec.go @@ -51,17 +51,80 @@ func shr10VU(z, x dec, s uint) (r Word) { return r } -// The resulting carry c is either 0 or 1. -func add10VW(z, x []Word, y Word) (c Word) { - c = y - for i := 0; i < len(z) && i < len(x); i++ { - zi, cc := bits.Add(uint(x[i]), uint(c), 0) - if zi >= _BD { - zi -= _BD - c = 1 +func decTrailingZeros(n uint) uint { + if bits.UintSize == 32 { + return dec32TrailingZeros(n) + } + return dec64TrailingZeros(uint64(n)) +} + +func dec32TrailingZeros(n uint) uint { + var d uint + if n%100000000 == 0 { + n /= 100000000 + d += 8 + } + if n%10000 == 0 { + n /= 10000 + d += 4 + } + if n%100 == 0 { + n /= 100 + d += 2 + } + if n%10 == 0 { + d += 1 + } + return d +} + +func dec64TrailingZeros(n uint64) uint { + var d uint + if n%10000000000000000 == 0 { + n /= 10000000000000000 + d += 16 + } + if n%100000000 == 0 { + n /= 100000000 + d += 8 + } + if n%10000 == 0 { + n /= 10000 + d += 4 + } + if n%100 == 0 { + n /= 100 + d += 2 + } + if n%10 == 0 { + d += 1 + } + return d +} + +// addW adds y to x. The resulting carry c is either 0 or 1. +func add10VW(z, x dec, y Word) (c Word) { + s := x[0] + y + if (s < y) || s >= _BD { + z[0] = s - _BD + c = 1 + } else { + z[0] = s + } + // propagate carry + for i := 1; i < len(z) && i < len(x); i++ { + s = x[i] + c + if s == _BD { + z[i] = 0 + continue + } + // c = 0 from this point + z[i] = s + // copy remaining digits if not adding in-place + if !same(z, x) { + copy(z[i+1:], x[i+1:]) } - z[i] = Word(zi) - c = Word(cc) + return 0 } - return + return c } diff --git a/dec.go b/dec.go index ea73717..f0f8476 100644 --- a/dec.go +++ b/dec.go @@ -2,6 +2,7 @@ package decimal import ( "math/big" + "math/bits" "sync" ) @@ -42,7 +43,7 @@ func (z dec) norm() (dec, uint) { } z = z[:i] if len(z) == 0 { - return z, ls + return z, 0 } // partial shift if s := _WD - mag(uint(z[len(z)-1])); s != 0 { @@ -89,31 +90,24 @@ func (z dec) shr10(s uint) (d dec, r Word, tnz bool) { return z, r, tnz || m != 0 } -func (x dec) digit(n uint) uint { - n, m := n/_WD, n%_WD - return (uint(x[n]) / pow10(m)) % 10 -} - func (x dec) digits() uint { - // const H = 9 - // const P = 1000000000 for i, w := range x { if w != 0 { - // TODO(db47h): is there a way to optimize this? - var d uint - // if w%P == 0 { - // w /= P - // d += H - // } - for ; w%10 != 0; w /= 10 { - d++ - } - return uint(len(x)-i)*_WD - d + return uint(len(x)-i)*_WD - decTrailingZeros(uint(w)) } } return 0 } +func (x dec) digit(i uint) uint { + j, i := bits.Div(0, i, _WD) + if j >= uint(len(x)) { + return 0 + } + // 0 <= j < len(x) + return (uint(x[j]) / pow10(i)) % 10 +} + func (z dec) set(x dec) dec { z = z.make(len(x)) copy(z, x) @@ -134,19 +128,39 @@ func (z dec) make(n int) dec { return make(dec, n, n+e) } -func (z dec) setInt(x *big.Int) dec { +// setInt sets z such that z*10**exp = x with 0 < z <= 1. +// Returns z and exp. +func (z dec) setInt(x *big.Int) (dec, uint) { b := new(big.Int).Set(x).Bits() - n := len(b) - i := 0 - for ; n > 0; i++ { - z[i] = Word(divWVW_g(b, 0, b, big.Word(_BD))) - n = len(b) - for n > 0 && b[n-1] == 0 { - n-- + var i int + for i = 0; i < len(z) && len(b) > 0; i++ { + z[i] = Word(divWVW(b, 0, b, big.Word(_BD))) + } + z = z[:i] + z, s := z.norm() + return z, uint(i)*_WD - s +} + +// sticky returns 1 if there's a non zero digit within the +// i least significant digits, otherwise it returns 0. +func (x dec) sticky(i uint) uint { + j, i := bits.Div(0, i, _WD) + if j >= uint(len(x)) { + if len(x) == 0 { + return 0 + } + return 1 + } + // 0 <= j < len(x) + for _, x := range x[:j] { + if x != 0 { + return 1 } - b = b[0:n] } - return z[:i] + if uint(x[j])%pow10(i) != 0 { + return 1 + } + return 0 } // getDec returns a *dec of len n. The contents may not be zero. diff --git a/dec_conv.go b/dec_conv.go new file mode 100644 index 0000000..6a188e8 --- /dev/null +++ b/dec_conv.go @@ -0,0 +1,7 @@ +package decimal + +import "io" + +func (z dec) scan(r io.ByteScanner, base int, fracOk bool) (res dec, b, count int, err error) { + panic("not implemented") +} diff --git a/dec_test.go b/dec_test.go index 237fcdb..e088251 100644 --- a/dec_test.go +++ b/dec_test.go @@ -1,14 +1,16 @@ package decimal import ( + "bytes" "math/big" "math/bits" "math/rand" + "strconv" "testing" "time" ) -func Test_decNorm(t *testing.T) { +func Test_dec_norm(t *testing.T) { rand.Seed(time.Now().UnixNano()) for i := 0; i < 10000; i++ { w := uint(rand.Uint64()) % _BD @@ -30,6 +32,25 @@ func Test_decNorm(t *testing.T) { } } +func Test_dec_digits(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + for i := 0; i < 10000; i++ { + again: + w := uint(rand.Uint64()) % _BD + // ignore anything divisible by ten since mag(10) = 2 but dec{100000000...}.digits() = 1 + if w%10 == 0 { + goto again + } + e := uint(rand.Intn(_WD + 1)) + h, l := bits.Mul(w, pow10(e)) + h, l = bits.Div(h, l, _BD) + d, _ := dec{Word(l), Word(h)}.norm() + if d.digits() != mag(w) { + t.Fatalf("dec{%d}.digits() = %d, expected %d", d[0], d.digits(), mag(w)) + } + } +} + func Test_mag(t *testing.T) { rand.Seed(time.Now().UnixNano()) for i := 0; i < 10000; i++ { @@ -44,12 +65,97 @@ func Test_mag(t *testing.T) { } } -func Test_decSetInt(t *testing.T) { - b, _ := new(big.Int).SetString("123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", 0) - prec := uint32(float64(b.BitLen())/ln2_10) + 1 - d := dec{}.make((int(prec) + _WD - 1) / _WD) - d = d.setInt(b) - t.Log(d, len(d)) +func Test_dec_setInt(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + for i := 0; i < 1000; i++ { + ns := make([]byte, rand.Intn(100)+1) + for i := 0; i < len(ns); i++ { + ns[i] = '0' + byte(rand.Intn(10)) + } + b, _ := new(big.Int).SetString(string(ns), 10) + // remove trailing 0s + ns = bytes.TrimLeft(ns, "0") + prec := uint32(float64(b.BitLen())/ln2_10) + 1 + d, exp := dec{}.make((int(prec) + _WD - 1) / _WD).setInt(b) + if exp != uint(len(ns)) { + t.Fatalf("%s -> %v. Expected exponent %d, got %d.", ns, d, len(ns), exp) + } + b2 := new(big.Int) + bd := new(big.Int).SetUint64(_BD) + x := new(big.Int) + for i := len(d) - 1; i >= 0; i-- { + b2.Mul(b2, bd).Add(b2, x.SetUint64(uint64(d[i]))) + } + shr := len(d)*_WD - int(exp) + if shr > 0 { + b2.Div(b2, x.SetUint64(uint64(pow10(uint(shr))))) + } else { + b2.Mul(b2, x.SetUint64(uint64(pow10(uint(-shr))))) + } + if b.Cmp(b2) != 0 { + t.Fatalf("Got %s -> %v x 10**%d. Bad conversion back to Int: %s", b, d, exp, b2) + } + } + b, _ := new(big.Int).SetString("12345678901234567890000000000000000000", 0) + d, exp := dec{}.make(3).setInt(b) + t.Log(d, exp) +} + +func Test_add10VW(t *testing.T) { + td := []struct { + i dec + x Word + o dec + c Word + s uint + }{ + {dec{_BD - 2, _BD - 1}, 2, nil, 1, 0}, + {dec{_BD - 2, _BD - 1}, 1, dec{_BD - 1, _BD - 1}, 0, 0}, + {dec{_BD - 2, _BD - 2}, 2, dec{_BD - 1}, 0, 0}, + } + for i, d := range td { + t.Run(strconv.Itoa(i), func(t *testing.T) { + z := d.i + c := add10VW(z, z, d.x) + z, s := z.norm() + ok := true + if len(z) != len(d.o) { + ok = false + } else { + for i := 0; i < len(z) && i < len(d.o); i++ { + if z[i] != d.o[i] { + ok = false + } + } + } + if !ok || s != d.s || c != d.c { + t.Fatalf("addW failed: expected z = %v, s = %d, c = %d, got d = %v, s = %v, c = %v", d.o, d.s, d.c, z, s, c) + } + + }) + } +} + +func TestDec_digit(t *testing.T) { + data := []struct { + d dec + n uint + r uint + }{ + {dec{123}, 0, 3}, + {dec{123}, 2, 1}, + {dec{123}, 3, 0}, + {dec{0, 1234567891234567891}, 37, 1}, + {dec{0, 1234567891234567891}, 36, 2}, + {dec{0, 1234567891234567891}, 38, 0}, + } + for di, d := range data { + t.Run(strconv.Itoa(di), func(t *testing.T) { + if dig := d.d.digit(d.n); dig != d.r { + t.Fatalf("%v.digit(%d) = %d, expected %d", d.d, d.n, dig, d.r) + } + }) + } } var ( @@ -57,15 +163,15 @@ var ( benchU uint ) -func Benchmark_decNorm(b *testing.B) { +func Benchmark_dec_norm(b *testing.B) { rand.Seed(0xdeadbeefbadf00d) - d := dec{}.make(2)[:0] + d := dec{}.make(10000) + for i := range d { + d[i] = Word(rand.Uint64()) % _BD + } for i := 0; i < b.N; i++ { - w := uint(rand.Uint64()) % _BD - e := uint(rand.Intn(_WD)) - h, l := w/pow10(_WD-e), (w%pow10(_WD-e))*pow10(e) - d = d.make(2) - d[0], d[1] = Word(l), Word(h) + d[0] = Word(rand.Uint64()) % _BD + d[len(d)-1] = Word(rand.Uint64()) % _BD benchD, benchU = d.norm() } } @@ -77,15 +183,15 @@ func Benchmark_mag(b *testing.B) { } } -func Benchmark_decDigits(b *testing.B) { +func Benchmark_dec_Digits(b *testing.B) { rand.Seed(0xdeadbeefbadf00d) - d := dec{}.make(1) + d := dec{}.make(10000) + for i := range d { + d[i] = Word(rand.Uint64()) % _BD + } for i := 0; i < b.N; i++ { - w := uint(rand.Uint64()) % _BD - e := uint(rand.Intn(_WD)) - h, l := bits.Mul(w, pow10(e)) - _, l = bits.Div(h, l, _BD) - d[0] = Word(l) + d[0] = Word(rand.Uint64()) % _BD + d[len(d)-1] = Word(rand.Uint64()) % _BD benchU = d.digits() } } diff --git a/decimal.go b/decimal.go index e8aafc6..e8cc6e6 100644 --- a/decimal.go +++ b/decimal.go @@ -10,7 +10,6 @@ type Decimal struct { mant dec exp int32 prec uint32 - dig uint32 mode RoundingMode acc Accuracy form form @@ -98,7 +97,7 @@ func (x *Decimal) MinPrec() uint { if x.form != finite { return 0 } - return uint(x.dig) + return x.mant.digits() } // Mode returns the rounding mode of x. @@ -146,7 +145,6 @@ func (z *Decimal) Set(x *Decimal) *Decimal { if z != x { z.form = x.form z.neg = x.neg - z.dig = x.dig if x.form == finite { z.exp = x.exp z.mant = z.mant.set(x.mant) @@ -154,7 +152,7 @@ func (z *Decimal) Set(x *Decimal) *Decimal { if z.prec == 0 { z.prec = x.prec } else if z.prec < x.prec { - z.round() + z.round(0) } } return z @@ -188,8 +186,7 @@ func (z *Decimal) SetInt(x *big.Int) *Decimal { z.prec = umax32(prec, _WD) } // TODO(db47h) truncating x could be more efficient if z.prec > 0 - // but small compared to the size of x, or if there - // are many trailing 0's. + // but small compared to the size of x, or if there are many trailing 0's. z.acc = Exact z.neg = x.Sign() < 0 if bits == 0 { @@ -198,10 +195,8 @@ func (z *Decimal) SetInt(x *big.Int) *Decimal { } // x != 0 exp := uint(0) - z.mant = z.mant.make((int(prec) + _WD - 1) / _WD).setInt(x) - z.mant, exp = z.mant.norm() - z.dig = uint32(z.mant.digits()) - z.setExpAndRound(int64(exp) + int64(z.dig)) + z.mant, exp = z.mant.make((int(prec) + _WD - 1) / _WD).setInt(x) + z.setExpAndRound(int64(exp), 0) return z } @@ -209,7 +204,7 @@ func (z *Decimal) SetInt64(x int64) *Decimal { panic("not implemented") } -func (z *Decimal) setExpAndRound(exp int64) { +func (z *Decimal) setExpAndRound(exp int64, sbit uint) { if exp < MinExp { // underflow z.acc = makeAcc(z.neg) @@ -226,7 +221,7 @@ func (z *Decimal) setExpAndRound(exp int64) { z.form = finite z.exp = int32(exp) - z.round() + z.round(sbit) } func (z *Decimal) SetMantExp(mant *Decimal, exp int) *Decimal { @@ -242,8 +237,35 @@ func (z *Decimal) SetMode(mode RoundingMode) *Decimal { return z } +// SetPrec sets z's precision to prec and returns the (possibly) rounded +// value of z. Rounding occurs according to z's rounding mode if the mantissa +// cannot be represented in prec digits without loss of precision. +// SetPrec(0) maps all finite values to ±0; infinite values remain unchanged. +// If prec > MaxPrec, it is set to MaxPrec. func (z *Decimal) SetPrec(prec uint) *Decimal { - panic("not implemented") + z.acc = Exact // optimistically assume no rounding is needed + + // special case + if prec == 0 { + z.prec = 0 + if z.form == finite { + // truncate z to 0 + z.acc = makeAcc(z.neg) + z.form = zero + } + return z + } + + // general case + if prec > MaxPrec { + prec = MaxPrec + } + old := z.prec + z.prec = uint32(prec) + if z.prec < old { + z.round(0) + } + return z } func (z *Decimal) SetRat(x *big.Rat) *Decimal { @@ -314,14 +336,8 @@ func (x *Decimal) validate() { if m == 0 { panic("nonzero finite number with empty mantissa") } - if x.mant[m-1] == 0 { - panic(fmt.Sprintf("last word of %s is zero", x.Text('e', 0))) - } - if x.mant[0]%10 == 0 { - panic(fmt.Sprintf("first word %d of %s is divisible by 10", x.mant[0], x.Text('e', 0))) - } - if d := uint32(x.mant.digits()); x.dig != d { - panic(fmt.Sprintf("digit count %d != real digit count %d for %s", x.dig, d, x.Text('e', 0))) + if msw := x.mant[m-1]; !(_BD/10 <= msw && msw < _BD) { + panic(fmt.Sprintf("last word of %s is not within [%d %d)", x.Text('e', 0), uint(_BD/10), uint(_BD))) } if x.prec == 0 { panic("zero precision finite number") @@ -335,8 +351,7 @@ func (x *Decimal) validate() { // CAUTION: The rounding modes ToNegativeInf, ToPositiveInf are affected by the // sign of z. For correct rounding, the sign of z must be set correctly before // calling round. -func (z *Decimal) round() { - var sbit bool +func (z *Decimal) round(sbit uint) { if debugDecimal { z.validate() } @@ -348,21 +363,39 @@ func (z *Decimal) round() { } // z.form == finite && len(z.mant) > 0 // m > 0 implies z.prec > 0 (checked by validate) - // m := uint32(len(z.mant)) // present mantissa length in words - if z.dig <= z.prec { + m := uint32(len(z.mant)) // present mantissa length in words + digits := m * _WD + if digits <= z.prec { // mantissa fits => nothing to do return } - // digits > z.prec - // r := uint(z.digits - z.prec - 1) - // rd := z.mant.digit(r) + // digits > z.prec: mantissa too large => round + r := uint(digits - z.prec - 1) // rounding digit position r >= 0 + rdigit := z.mant.digit(r) // rounding digit + + if sbit == 0 && (rdigit == 0 || z.mode == ToNearestEven) { + // The sticky bit is only needed for rounding ToNearestEven + // or when the rounding bit is zero. Avoid computation otherwise. + sbit = z.mant.sticky(r) + } + sbit &= 1 // be safe and ensure it's a single bit // cut off extra words + + n := (z.prec + (_WD - 1)) / _WD // mantissa length in words for desired precision + if m > n { + copy(z.mant, z.mant[m-n:]) // move n last words to front + z.mant = z.mant[:n] + } - var r Word - z.mant, r, sbit = z.mant.shr10(uint(z.dig - z.prec)) - z.dig = z.prec + // determine number of trailing zero digits (ntz) and compute lsd of mantissa's least-significant word + ntz := uint(n*_WD - z.prec) // 0 <= ntz < _W + lsd := pow10(ntz) - if r != 0 || sbit { + // round if result is inexact + if rdigit|sbit != 0 { + // Make rounding decision: The result mantissa is truncated ("rounded down") + // by default. Decide if we need to increment, or "round up", the (unsigned) + // mantissa. inc := false switch z.mode { case ToNegativeInf: @@ -370,9 +403,9 @@ func (z *Decimal) round() { case ToZero: // nothing to do case ToNearestEven: - inc = r > 5 || (r == 5 && (sbit || z.mant[0]&1 != 0)) + inc = rdigit > 5 || (rdigit == 5 && (sbit != 0 || z.mant.digit(ntz)&1 != 0)) case ToNearestAway: - inc = r >= 5 + inc = rdigit >= 5 case AwayFromZero: inc = true case ToPositiveInf: @@ -383,11 +416,24 @@ func (z *Decimal) round() { z.acc = makeAcc(inc != z.neg) if inc { // add 1 to mantissa - if add10VW(z.mant, z.mant, 1) != 0 { - + if add10VW(z.mant, z.mant, Word(lsd)) != 0 { + // mantissa overflow => adjust exponent + if z.exp >= MaxExp { + // exponent overflow + z.form = inf + return + } + z.exp++ + // mantissa overflow means that the mantissa before increment + // was all nines. In that case, the result is 1**(z.exp+1) + z.mant[n-1] = _BD / 10 } } } + + // zero out trailing digits in least-significant word + z.mant[0] -= z.mant[0] % Word(lsd) + if debugDecimal { z.validate() } diff --git a/decconv.go b/decimal_conv.go similarity index 64% rename from decconv.go rename to decimal_conv.go index cda28ef..0fa4dac 100644 --- a/decconv.go +++ b/decimal_conv.go @@ -25,7 +25,124 @@ func (z *Decimal) SetString(s string) (*Decimal, bool) { // as the implementation of Parse. It does not recognize ±Inf and does not expect // EOF at the end. func (z *Decimal) scan(r io.ByteScanner, base int) (f *Decimal, b int, err error) { - panic("not implemented") + prec := z.prec + if prec == 0 { + prec = _WD + } + + // A reasonable value in case of an error. + z.form = zero + + // sign + z.neg, err = scanSign(r) + if err != nil { + return + } + + // mantissa + var fcount int // fractional digit count; valid if <= 0 + z.mant, b, fcount, err = z.mant.scan(r, base, true) + if err != nil { + return + } + + // exponent + var exp int64 + var ebase int + exp, ebase, err = scanExponent(r, true, base == 0) + if err != nil { + return + } + + // special-case 0 + if len(z.mant) == 0 { + z.prec = prec + z.acc = Exact + z.form = zero + f = z + return + } + // len(z.mant) > 0 + + // The mantissa may have a radix point (fcount <= 0) and there + // may be a nonzero exponent exp. The radix point amounts to a + // division by b**(-fcount). An exponent means multiplication by + // ebase**exp. Finally, mantissa normalization (shift left) requires + // a correcting multiplication by 2**(-shiftcount). Multiplications + // are commutative, so we can apply them in any order as long as there + // is no loss of precision. We only have powers of 2 and 10, and + // we split powers of 10 into the product of the same powers of + // 2 and 5. This reduces the size of the multiplication factor + // needed for base-10 exponents. + + panic("adjust exponent needed") + + // normalize mantissa and determine initial exponent contributions + // exp2 := int64(len(z.mant))*_W - fnorm(z.mant) + exp2 := int64(0) + exp5 := int64(0) + + // determine binary or decimal exponent contribution of radix point + if fcount < 0 { + // The mantissa has a radix point ddd.dddd; and + // -fcount is the number of digits to the right + // of '.'. Adjust relevant exponent accordingly. + d := int64(fcount) + switch b { + case 10: + exp5 = d + fallthrough // 10**e == 5**e * 2**e + case 2: + exp2 += d + case 8: + exp2 += d * 3 // octal digits are 3 bits each + case 16: + exp2 += d * 4 // hexadecimal digits are 4 bits each + default: + panic("unexpected mantissa base") + } + // fcount consumed - not needed anymore + } + + // take actual exponent into account + switch ebase { + case 10: + exp5 += exp + fallthrough // see fallthrough above + case 2: + exp2 += exp + default: + panic("unexpected exponent base") + } + // exp consumed - not needed anymore + + // apply 2**exp2 + if MinExp <= exp2 && exp2 <= MaxExp { + z.prec = prec + z.form = finite + z.exp = int32(exp2) + f = z + } else { + err = fmt.Errorf("exponent overflow") + return + } + + if exp5 == 0 { + // no decimal exponent contribution + z.round(0) + return + } + // exp5 != 0 + + // // apply 5**exp5 + // p := new(Decimal).SetPrec(z.Prec() + _WD) // use more bits for p -- TODO(gri) what is the right number? + // if exp5 < 0 { + // z.Quo(z, p.pow5(uint64(-exp5))) + // } else { + // z.Mul(z, p.pow5(uint64(exp5))) + // } + + return } // Parse parses s which must contain a text representation of a floating- point diff --git a/decimal_test.go b/decimal_test.go index 1b8584b..02a6358 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -1,12 +1,58 @@ package decimal import ( + "math/big" + "reflect" + "strconv" "testing" ) +var intData = []struct { + s string + p uint + d dec + pr uint + e int32 +}{ + {"1234567890123456789_0123456789012345678_9012345678901234567_8901234567890123456_78901234567890", 0, + dec{7890123456789000000, 8901234567890123456, 9012345678901234567, 123456789012345678, 1234567890123456789}, + 90, 90}, + {"1235", 3, dec{1240000000000000000}, 3, 4}, + {"1245", 3, dec{1240000000000000000}, 3, 4}, + {"12451", 3, dec{1250000000000000000}, 3, 5}, +} + func TestDecimal_SetInt(t *testing.T) { - // b, _ := new(big.Int).SetString("123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", 0) - // d := new(Decimal) - // d.SetInt(b) - // t.Log(d.prec, d.dig, d.exp) + for i, td := range intData { + t.Run(strconv.Itoa(i), func(t *testing.T) { + b, _ := new(big.Int).SetString(td.s, 0) + d := new(Decimal).SetMode(ToNearestEven).SetPrec(td.p).SetInt(b) + if !reflect.DeepEqual(td.d, d.mant) { + t.Fatalf("\nexpected mantissa %v\n got %v", td.d, d.mant) + } + if td.pr != d.Prec() { + t.Fatalf("\nexpected precision %v\n got %v", td.pr, d.Prec()) + } + if td.e != d.exp { + t.Fatalf("\nexpected exponent %v\n got %v", td.p, d.Prec()) + } + }) + } +} + +func TestDecimal_SetString(t *testing.T) { + for i, td := range intData { + t.Run(strconv.Itoa(i), func(t *testing.T) { + d, _ := new(Decimal).SetMode(ToNearestEven).SetPrec(td.p).SetString(td.s) + if !reflect.DeepEqual(td.d, d.mant) { + t.Fatalf("\nexpected mantissa %v\n got %v", td.d, d.mant) + } + if td.pr != d.Prec() { + t.Fatalf("\nexpected precision %v\n got %v", td.pr, d.Prec()) + } + if td.e != d.exp { + t.Fatalf("\nexpected exponent %v\n got %v", td.p, d.Prec()) + } + }) + } } diff --git a/stdlib.go b/stdlib.go index f824083..51b26ee 100644 --- a/stdlib.go +++ b/stdlib.go @@ -3,10 +3,13 @@ package decimal import ( + "errors" "fmt" + "io" "math" "math/big" "math/bits" + "strconv" ) // MaxBase is the largest number base accepted for string conversions. @@ -114,15 +117,118 @@ func umax32(x, y uint32) uint32 { } // q = (u1<<_W + u0 - r)/v -func divWW_g(u1, u0, v big.Word) (q, r big.Word) { +func divWW(u1, u0, v big.Word) (q, r big.Word) { qq, rr := bits.Div(uint(u1), uint(u0), uint(v)) return big.Word(qq), big.Word(rr) } -func divWVW_g(z []big.Word, xn big.Word, x []big.Word, y big.Word) (r big.Word) { +func divWVW(z []big.Word, xn big.Word, x []big.Word, y big.Word) (r big.Word) { r = xn for i := len(z) - 1; i >= 0; i-- { - z[i], r = divWW_g(r, x[i], y) + z[i], r = divWW(r, x[i], y) } return r } + +func same(x, y []Word) bool { + return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0] +} + +// scan errors +var ( + errNoDigits = errors.New("number has no digits") + errInvalSep = errors.New("'_' must separate successive digits") +) + +func scanSign(r io.ByteScanner) (neg bool, err error) { + var ch byte + if ch, err = r.ReadByte(); err != nil { + return false, err + } + switch ch { + case '-': + neg = true + case '+': + // nothing to do + default: + r.UnreadByte() + } + return +} + +func scanExponent(r io.ByteScanner, base2ok, sepOk bool) (exp int64, base int, err error) { + // one char look-ahead + ch, err := r.ReadByte() + if err != nil { + if err == io.EOF { + err = nil + } + return 0, 10, err + } + + // exponent char + switch ch { + case 'e', 'E': + base = 10 + case 'p', 'P': + if base2ok { + base = 2 + break // ok + } + fallthrough // binary exponent not permitted + default: + r.UnreadByte() // ch does not belong to exponent anymore + return 0, 10, nil + } + + // sign + var digits []byte + ch, err = r.ReadByte() + if err == nil && (ch == '+' || ch == '-') { + if ch == '-' { + digits = append(digits, '-') + } + ch, err = r.ReadByte() + } + + // prev encodes the previously seen char: it is one + // of '_', '0' (a digit), or '.' (anything else). A + // valid separator '_' may only occur after a digit. + prev := '.' + invalSep := false + + // exponent value + hasDigits := false + for err == nil { + if '0' <= ch && ch <= '9' { + digits = append(digits, ch) + prev = '0' + hasDigits = true + } else if ch == '_' && sepOk { + if prev != '0' { + invalSep = true + } + prev = '_' + } else { + r.UnreadByte() // ch does not belong to number anymore + break + } + ch, err = r.ReadByte() + } + + if err == io.EOF { + err = nil + } + if err == nil && !hasDigits { + err = errNoDigits + } + if err == nil { + exp, err = strconv.ParseInt(string(digits), 10, 64) + } + // other errors take precedence over invalid separators + if err == nil && (invalSep || prev == '_') { + err = errInvalSep + } + + return +}