-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathcnn.go
133 lines (108 loc) · 3.33 KB
/
cnn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main
import (
"fmt"
"log"
"runtime"
"sync"
"time"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
"github.com/sugarme/gotch/vision"
)
// Hyperparameters for the MNIST CNN example.
const (
	// MnistDirCNN is the directory holding the raw MNIST data files.
	MnistDirCNN string = "../../data/mnist"
	epochsCNN          = 30       // number of full passes over the training set
	batchCNN           = 256      // NOTE(review): unused in this file — confirm it is used elsewhere in the package
	// batchSize = 256
	batchSize = 32       // mini-batch size used by the training loop
	LrCNN     = 3 * 1e-4 // Adam learning rate
)
var mu sync.Mutex
// Net is a small convolutional network for MNIST digit classification:
// two 2-D convolution layers followed by two fully connected layers
// (the final layer emits 10 class logits — see newNet/ForwardT).
type Net struct {
	conv1 *nn.Conv2D
	conv2 *nn.Conv2D
	fc1   *nn.Linear
	fc2   *nn.Linear
}
// newNet builds the CNN under the given variable-store path: two conv
// layers with 5x5 kernels (1->32->64 channels) and two linear layers
// (1024->1024->10), all with the library's default configs.
func newNet(vs *nn.Path) *Net {
	return &Net{
		conv1: nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig()),
		conv2: nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig()),
		fc1:   nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig()),
		fc2:   nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig()),
	}
}
// ForwardT runs a forward pass. Input xs is reshaped to NCHW 1x28x28
// images, passed through conv->pool->conv->pool, flattened to 1024
// features, then fc1 -> ReLU -> dropout(0.5, active only when train) -> fc2.
// The boolean flags on the Must* calls control freeing of intermediate
// tensors and are preserved exactly from the original implementation.
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
	batched := xs.MustView([]int64{-1, 1, 28, 28}, false)
	conv1Out := batched.Apply(n.conv1)
	pool1 := conv1Out.MaxPool2DDefault(2, true)
	conv2Out := pool1.Apply(n.conv2)
	pool2 := conv2Out.MaxPool2DDefault(2, true)
	flat := pool2.MustView([]int64{-1, 1024}, true)
	hidden := flat.Apply(n.fc1)
	activated := hidden.MustRelu(false)
	dropped := ts.MustDropout(activated, 0.5, train)
	return dropped.Apply(n.fc2)
}
// runCNN1 trains the CNN on MNIST with Adam for epochsCNN epochs,
// evaluating test accuracy after every epoch and reporting the best
// accuracy and total wall-clock time at the end.
func runCNN1() {
	// FIX: the original loaded from MnistDirNN, which is not declared in
	// this file; MnistDirCNN is the constant declared here for this example.
	ds := vision.LoadMNISTDir(MnistDirCNN)

	// NOTE(review): `device` is not declared in this file — presumably a
	// package-level variable defined in a sibling file; confirm.
	trainImages := ds.TrainImages.MustTo(device, false) // [60000, 784]
	trainLabels := ds.TrainLabels.MustTo(device, false) // [60000] (FIX: comment previously claimed [60000, 784])
	testImages := ds.TestImages.MustTo(device, false)   // [10000, 784]
	testLabels := ds.TestLabels.MustTo(device, false)   // [10000] (FIX: comment previously claimed [10000, 784])
	fmt.Printf("testImages: %v\n", testImages.MustSize())
	fmt.Printf("testLabels: %v\n", testLabels.MustSize())

	vs := nn.NewVarStore(device)
	net := newNet(vs.Root())
	opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
	if err != nil {
		log.Fatal(err)
	}

	bestAccuracy := 0.0
	startTime := time.Now()

	for epoch := 0; epoch < epochsCNN; epoch++ {
		totalSize := ds.TrainImages.MustSize()[0]
		samples := int(totalSize)

		// Shuffle the training set each epoch via a random permutation.
		index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
		imagesTs := trainImages.MustIndexSelect(0, index, false)
		labelsTs := trainLabels.MustIndexSelect(0, index, false)

		batches := samples / batchSize
		batchIndex := 0
		var epochLoss float64
		for i := 0; i < batches; i++ {
			start := batchIndex * batchSize
			size := batchSize
			if samples-start < batchSize {
				// Drop the final partial batch rather than training on it.
				break
			}
			batchIndex++

			bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
			logits := net.ForwardT(bImages, true)
			bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
			loss := logits.CrossEntropyForLogits(bLabels)
			// NOTE(review): forcing requires_grad on the loss looks like a
			// workaround — a loss produced by the graph should already carry
			// gradients. Kept to preserve behavior; confirm it is needed.
			loss = loss.MustSetRequiresGrad(true, true)
			opt.BackwardStep(loss)
			epochLoss = loss.Float64Values()[0] // last batch's loss, reported per epoch
			runtime.GC()
		}

		// Evaluate without tracking gradients.
		ts.NoGrad(func() {
			fmt.Printf("Start eval...")
			testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
			fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epochLoss, testAccuracy*100.0)
			if testAccuracy > bestAccuracy {
				bestAccuracy = testAccuracy
			}
		})
	}

	fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
	ts.CleanUp()
}