Skip to content

Commit

Permalink
example/jit-train: added
Browse files Browse the repository at this point in the history
  • Loading branch information
sugarme committed Jan 2, 2021
1 parent 86c817b commit d6fb8d8
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

*.txt
*.json
*.pt
data/*.pt
*.ot

target/
Expand Down
98 changes: 94 additions & 4 deletions example/jit-train/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,25 @@ import (

"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
// ts "github.com/sugarme/gotch/tensor"
// "github.com/sugarme/gotch/vision"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)

// main loads the MNIST data from "../../data/mnist", re-views the image
// tensors with shape [-1, 1, 28, 28], picks a device with
// gotch.CudaIfAvailable and reports the accuracy of the saved model through
// loadTrainedAndTestAcc.
func main() {
// NOTE(review): this one-argument call looks like the pre-change line left
// over from the diff view; the two-argument form appears below — confirm
// which one belongs in the file.
runTrainAndSaveModel(gotch.CudaIfAvailable())
ds := vision.LoadMNISTDir("../../data/mnist")
// Rebuild the dataset with the images viewed as [-1, 1, 28, 28]; labels are
// reused unchanged. The trailing `true` passed to MustView is presumably a
// "delete the source tensor" flag — TODO confirm against the tensor API.
dataset := &vision.Dataset{
TestImages: ds.TestImages.MustView([]int64{-1, 1, 28, 28}, true),
TrainImages: ds.TrainImages.MustView([]int64{-1, 1, 28, 28}, true),
TestLabels: ds.TestLabels,
TrainLabels: ds.TrainLabels,
}
device := gotch.CudaIfAvailable()

// Training is disabled here; only the evaluation of "./trained-model.pt"
// runs. Uncomment the next line to train and save the model first.
// runTrainAndSaveModel(dataset, device)
loadTrainedAndTestAcc(dataset, device)
}

func runTrainAndSaveModel(device gotch.Device) {
func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {

file := "./model.pt"
vs := nn.NewVarStore(device)
Expand All @@ -32,4 +42,84 @@ func runTrainAndSaveModel(device gotch.Device) {
for _, x := range namedTensors {
fmt.Println(x.Name)
}

trainable.SetTrain()
initialAcc := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
fmt.Printf("Initial Accuracy: %0.4f\n", initialAcc)
bestAccuracy := initialAcc

opt, err := nn.DefaultAdamConfig().Build(vs, 1e-4)
if err != nil {
log.Fatal(err)
}
batchSize := 128
for epoch := 0; epoch < 20; epoch++ {

totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)

batches := samples / batchSize
batchIndex := 0
var epocLoss *ts.Tensor
for i := 0; i < batches; i++ {
start := batchIndex * batchSize
size := batchSize
if samples-start < batchSize {
break
}
batchIndex += 1

// Indexing
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
bImages := imagesTs.Idx(narrowIndex)
bLabels := labelsTs.Idx(narrowIndex)

bImages = bImages.MustTo(vs.Device(), true)
bLabels = bLabels.MustTo(vs.Device(), true)

logits := trainable.ForwardT(bImages, true)
loss := logits.CrossEntropyForLogits(bLabels)

opt.BackwardStep(loss)

epocLoss = loss.MustShallowClone()
epocLoss.Detach_()

// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])

bImages.MustDrop()
bLabels.MustDrop()
}

testAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
if testAccuracy > bestAccuracy {
bestAccuracy = testAccuracy
}

epocLoss.MustDrop()
imagesTs.MustDrop()
labelsTs.MustDrop()
}

err = trainable.Save("trained-model.pt")
if err != nil {
log.Fatal(err)
}
}

// loadTrainedAndTestAcc loads the module saved at "./trained-model.pt" into a
// new VarStore created on device, switches it to inference mode with SetEval
// and prints its accuracy on ds.TestImages / ds.TestLabels.
//
// The program exits through log.Fatal if the module cannot be loaded.
func loadTrainedAndTestAcc(ds *vision.Dataset, device gotch.Device) {
	store := nn.NewVarStore(device)

	model, loadErr := nn.TrainableCModuleLoad(store.Root(), "./trained-model.pt")
	if loadErr != nil {
		log.Fatal(loadErr)
	}

	// Evaluate in inference mode, not training mode.
	model.SetEval()

	accuracy := nn.BatchAccuracyForLogits(store, model, ds.TestImages, ds.TestLabels, device, 1024)
	fmt.Printf("Accuracy: %0.4f\n", accuracy)
}
Binary file added example/jit-train/model.pt
Binary file not shown.
Binary file added example/jit-train/trained-model.pt
Binary file not shown.
Binary file added example/jit/model.pt
Binary file not shown.
10 changes: 10 additions & 0 deletions nn/jit.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,13 @@ func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {

return retVal
}

// SetTrain sets the TrainableCModule to train mode by delegating to
// m.Inner.SetTrain.
func (m *TrainableCModule) SetTrain() {
m.Inner.SetTrain()
}

// SetEval sets the TrainableCModule to inference mode by delegating to
// m.Inner.SetEval.
func (m *TrainableCModule) SetEval() {
m.Inner.SetEval()
}
8 changes: 4 additions & 4 deletions tensor/jit.go
Original file line number Diff line number Diff line change
Expand Up @@ -1104,16 +1104,16 @@ func (cm *CModule) SetProfilingMode(b bool) {
}
}

// Train set CModule to train mode
func (cm *CModule) Train() {
// SetTrain sets the CModule to train mode.
//
// It calls lib.AtmTrain on cm.Cmodule and then checks TorchErr; if an error
// is reported the program is terminated via log.Fatal.
func (cm *CModule) SetTrain() {
	lib.AtmTrain(cm.Cmodule)

	err := TorchErr()
	if err == nil {
		return
	}
	log.Fatal(err)
}

// Eval set CModule to inference mode
func (cm *CModule) Eval() {
// SetEval sets the CModule to inference mode.
func (cm *CModule) SetEval() {
lib.AtmEval(cm.Cmodule)
if err := TorchErr(); err != nil {
log.Fatal(err)
Expand Down

0 comments on commit d6fb8d8

Please sign in to comment.