Skip to content

Commit

Permalink
refactor: Train err and new Test
Browse files Browse the repository at this point in the history
  • Loading branch information
lvisei committed Dec 15, 2020
1 parent 12fd1bb commit b501d03
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 9 deletions.
4 changes: 3 additions & 1 deletion examples/csv/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ func main() {
defer timeCost()("训练模型与插值生成网格图片总耗时")

ordinaryKriging := ordinarykriging.NewOrdinary(data["values"], data["x"], data["y"])
_ = ordinaryKriging.Train(ordinarykriging.Spherical, 0, 100)
if _, err := ordinaryKriging.Train(ordinarykriging.Exponential, 0, 100); err != nil {
log.Fatal(err)
}

_ = polygon
gridPlot(ordinaryKriging, polygon)
Expand Down
10 changes: 8 additions & 2 deletions examples/tinygo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"github.com/liuvigongzuoshi/go-kriging/ordinarykriging"
"log"
)

// tinygo build -o kriging.wasm -opt z -heap-size 2048M -target wasm ./main.go
Expand Down Expand Up @@ -1832,12 +1833,17 @@ func OrdinaryKrigingFunc() int {

func OrdinaryKriging(values, lons, lats []float64, model string, sigma2, alpha, width float64, polygon ordinarykriging.PolygonCoordinates) *ordinarykriging.GridMatrices {
ordinaryKriging := ordinarykriging.NewOrdinary(values, lons, lats)
_ = ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
if _, err := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha); err != nil {
log.Fatal(err)
}
return ordinaryKriging.Grid(polygon, width)
}

func OrdinaryKrigingTrain(values, lons, lats []float64, model string, sigma2 float64, alpha float64) *ordinarykriging.Variogram {
ordinaryKriging := ordinarykriging.NewOrdinary(values, lons, lats)
variogram := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
variogram, err := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
if err != nil {
log.Fatal(err)
}
return variogram
}
9 changes: 7 additions & 2 deletions examples/wasm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ func OrdinaryKrigingTrainFunc(this js.Value, args []js.Value) interface{} {

func RunOrdinaryKrigingTrain(values, lons, lats []float64, model string, sigma2 float64, alpha float64) *ordinarykriging.Variogram {
ordinaryKriging := ordinarykriging.NewOrdinary(values, lons, lats)
variogram := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
variogram, err := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
if err != nil {
log.Fatal(err)
}
return variogram
}

func RunOrdinaryKriging(values, lons, lats []float64, model string, sigma2, alpha, width float64, polygon ordinarykriging.PolygonCoordinates) *ordinarykriging.GridMatrices {
ordinaryKriging := ordinarykriging.NewOrdinary(values, lons, lats)
_ = ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha)
if _, err := ordinaryKriging.Train(ordinarykriging.ModelType(model), sigma2, alpha); err != nil {
log.Fatal(err)
}
return ordinaryKriging.Grid(polygon, width)
}

Expand Down
8 changes: 4 additions & 4 deletions ordinarykriging/ordinarykriging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package ordinarykriging

import (
"errors"
"github.com/liuvigongzuoshi/go-kriging/canvas"
"image"
"image/color"
Expand Down Expand Up @@ -62,7 +63,7 @@ func krigingVariogramSpherical(h, nugget, range_, sill, A float64) float64 {
}

// Train using gaussian processes with bayesian priors
func (variogram *Variogram) Train(model ModelType, sigma2 float64, alpha float64) *Variogram {
func (variogram *Variogram) Train(model ModelType, sigma2 float64, alpha float64) (*Variogram, error) {
variogram.Nugget = 0.0
variogram.Range = 0.0
variogram.Sill = 0.0
Expand Down Expand Up @@ -147,8 +148,7 @@ func (variogram *Variogram) Train(model ModelType, sigma2 float64, alpha float64
k = 0
}
if l < 2 {
// Error: Not enough points
return variogram
return nil, errors.New("not enough points")
}
}

Expand Down Expand Up @@ -232,7 +232,7 @@ func (variogram *Variogram) Train(model ModelType, sigma2 float64, alpha float64
variogram.K = K
variogram.M = M

return variogram
return variogram, nil
}

// Predict model prediction
Expand Down
50 changes: 50 additions & 0 deletions ordinarykriging/ordinarykriging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,56 @@ func generateData(count int) (FloatList, FloatList, FloatList) {
return randomValues, randomLats, randomLons
}

func TestVariogram_Predict_100(t *testing.T) {
var randomValues, randomLats, randomLons FloatList
randomLons = append(randomLons, randFloats(-90, 90, 100)...)
randomLats = append(randomLats, randFloats(-180, 180, 100)...)
for _, lon := range randomLons {
var value float64
if lon > 0 {
value = 100
} else {
value = 0
}
randomValues = append(randomValues, value)
}
ordinaryKriging := ordinarykriging.NewOrdinary(randomValues, randomLats, randomLons)
if _, err := ordinaryKriging.Train(ordinarykriging.Exponential, 0, 10); err != nil {
t.Fatal("variogram is null", err)
}
if ordinaryKriging.Predict(180, 0) < 50 {
t.Fatal("unexpected result (<50)")
}
if ordinaryKriging.Predict(-180, 0) > 50 {
t.Fatal("unexpected result (>50)")
}
}

func TestVariogram_Predict_1000(t *testing.T) {
var randomValues, randomLats, randomLons FloatList
randomLons = append(randomLons, randFloats(-90, 90, 1000)...)
randomLats = append(randomLats, randFloats(-180, 180, 1000)...)
for _, lon := range randomLons {
var value float64
if lon > 0 {
value = 100
} else {
value = 0
}
randomValues = append(randomValues, value)
}
ordinaryKriging := ordinarykriging.NewOrdinary(randomValues, randomLats, randomLons)
if _, err := ordinaryKriging.Train(ordinarykriging.Exponential, 0, 10); err != nil {
t.Fatal("variogram is null", err)
}
if ordinaryKriging.Predict(180, 0) < 50 {
t.Fatal("unexpected result (<50)")
}
if ordinaryKriging.Predict(-180, 0) > 50 {
t.Fatal("unexpected result (>50)")
}
}

func TestVariogram_Plot(t *testing.T) {
ordinaryKriging := ordinarykriging.NewOrdinary(randomValues, randomLats, randomLons)
ordinaryKriging.Train(ordinarykriging.Exponential, 0, 100)
Expand Down

0 comments on commit b501d03

Please sign in to comment.