-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathProgram.cs
152 lines (121 loc) · 4.51 KB
/
Program.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
using GGMLSharp;
using ModelLoader;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using static GGMLSharp.Structs;
namespace MNIST_CPU
{
internal class Program
{
/// <summary>
/// Entry point: loads the MNIST model and a raw 28x28 grayscale image from the
/// <c>Assets</c> folder (relative to the current directory), prints the image as
/// ASCII art, runs inference and prints the predicted digit.
/// </summary>
/// <param name="args">Command-line arguments (unused).</param>
/// <exception cref="InvalidDataException">
/// The image file does not contain exactly 28 * 28 bytes.
/// </exception>
static void Main(string[] args)
{
    const int ImageSide = 28;
    const int PixelCount = ImageSide * ImageSide;

    // Path.Combine keeps the relative asset paths valid on non-Windows platforms,
    // where a literal ".\Assets\..." would be treated as a single file name.
    string modelPath = Path.Combine("Assets", "mnist_model.state_dict");
    string imagePath = Path.Combine("Assets", "image.raw");

    Model model = LoadModel(modelPath);
    byte[] bytes = File.ReadAllBytes(imagePath);

    // Both the preview loop and the input buffer assume exactly one byte per pixel;
    // fail with a clear message instead of an IndexOutOfRangeException.
    if (bytes.Length != PixelCount)
    {
        throw new InvalidDataException(
            $"Expected '{imagePath}' to contain {PixelCount} bytes ({ImageSide}x{ImageSide}), but it contains {bytes.Length}.");
    }

    Console.WriteLine("The image is:");
    for (int row = 0; row < ImageSide; row++)
    {
        for (int col = 0; col < ImageSide; col++)
        {
            // Pixel values above 200 are drawn as blanks, everything else as '*'.
            Console.Write(bytes[row * ImageSide + col] > 200 ? " " : "*");
        }
        Console.WriteLine();
    }

    // Scale the byte values from 0..255 to 0..1 before feeding them to the model.
    float[] digit = new float[PixelCount];
    for (int i = 0; i < PixelCount; i++)
    {
        digit[i] = bytes[i] / 255.0f;
    }

    int prediction = Eval(model, digit);
    Console.WriteLine("Prediction: {0}", prediction);

    // Console.ReadKey throws when standard input is redirected (pipes, CI runs),
    // so only wait for a key press in an interactive console.
    if (!Console.IsInputRedirected)
    {
        Console.ReadKey();
    }
}
/// <summary>
/// Runs one forward pass of the model on <paramref name="digit"/> and returns the
/// index of the largest value in the output tensor.
/// </summary>
/// <param name="model">Model whose tensors, backend and buffer are already set up.</param>
/// <param name="digit">Pixel values written into the model's input tensor.</param>
/// <returns>Index of the first occurrence of the maximum output value.</returns>
/// <remarks>
/// Calls <c>Free</c> on the model's context before returning.
/// NOTE(review): presumably the model cannot be evaluated again afterwards - confirm.
/// </remarks>
private static int Eval(Model model, float[] digit)
{
    // Upload the input values to the input tensor.
    model.input.SetBackend(digit);

    // The graph allocator is used to work out the temporary memory needed for the computation.
    model.allocr = new SafeGGmlGraphAllocr(model.buffer.BufferType);

    // Build the graph once so the allocator can reserve memory for it.
    BuildGraph(model);
    model.graph.Reserve(model.allocr);

    ulong computeBufferBytes = model.allocr.GetBufferSize(0);
    Console.WriteLine($"compute buffer size: {computeBufferBytes / 1024.0} KB");

    // Run the graph and read the output tensor back as floats.
    byte[] rawProbabilities = Compute(model).GetBackend();
    List<float> probabilities = DataConverter.ConvertToFloats(rawProbabilities).ToList();

    // Arg-max: position of the first element equal to the maximum.
    float highest = probabilities.Max();
    int bestIndex = probabilities.IndexOf(highest);

    model.context.Free();
    return bestIndex;
}
/// <summary>
/// Plain container for everything the MNIST network needs: its tensors plus the
/// GGML objects used to allocate and run it.
/// </summary>
public class Model
{
    // Input tensor; LoadModel creates it as a 1-D F32 tensor of 28 * 28 elements.
    public SafeGGmlTensor input;

    // First layer parameters, loaded from "fc1.weight" / "fc1.bias".
    public SafeGGmlTensor fc1Weight, fc1Bias;

    // Second layer parameters, loaded from "fc2.weight" / "fc2.bias".
    public SafeGGmlTensor fc2Weight, fc2Bias;

    // Backend the graph is computed on (CUDA or CPU, chosen in LoadModel).
    public SafeGGmlBackend backend;

    // Buffer returned by BackendAllocContextTensors for the context's tensors.
    public SafeGGmlBackendBuffer buffer;

    // Context that creates the tensors and the graph.
    public SafeGGmlContext context;

    // Compute graph built by BuildGraph.
    public SafeGGmlGraph graph;

    // Graph allocator created in Eval and used to reserve/allocate the graph.
    public SafeGGmlGraphAllocr allocr;
}
/// <summary>
/// Loads the MNIST network from a pickled state dict: picks a backend (CUDA when
/// available, otherwise CPU), creates the input and parameter tensors, allocates
/// them on the backend and uploads the parameter data.
/// </summary>
/// <param name="path">Path of the state-dict file to read.</param>
/// <returns>A model with backend, context, buffer and all tensors initialised.</returns>
/// <exception cref="InvalidDataException">
/// One of "fc1.weight", "fc1.bias", "fc2.weight" or "fc2.bias" is missing from the file.
/// </exception>
/// <exception cref="InvalidOperationException">No backend could be initialised.</exception>
public static Model LoadModel(string path)
{
    PickleLoader pickleLoader = new PickleLoader();
    List<Tensor> tensors = pickleLoader.ReadTensorsInfoFromFile(path);

    // Resolve every required tensor once, before any backend objects are created,
    // so a malformed file fails early with a descriptive error.
    Tensor fc1WeightInfo = FindRequiredTensor(tensors, "fc1.weight", path);
    Tensor fc1BiasInfo = FindRequiredTensor(tensors, "fc1.bias", path);
    Tensor fc2WeightInfo = FindRequiredTensor(tensors, "fc2.weight", path);
    Tensor fc2BiasInfo = FindRequiredTensor(tensors, "fc2.bias", path);

    Model model = new Model();
    if (SafeGGmlBackend.HasCuda)
    {
        model.backend = SafeGGmlBackend.CudaInit(); // init device 0
        if (model.backend == null)
        {
            // Only report a CUDA failure when CUDA was actually attempted.
            Console.WriteLine("ggml_backend_cuda_init() failed.");
            Console.WriteLine("we will use ggml_backend_cpu_init() instead.");
        }
    }

    // If there is no usable GPU backend, fall back to the CPU backend.
    if (model.backend == null)
    {
        model.backend = SafeGGmlBackend.CpuInit();
    }

    if (model.backend == null)
    {
        throw new InvalidOperationException("ggml_backend_cpu_init() failed.");
    }

    model.context = new SafeGGmlContext(IntPtr.Zero, NoAllocateMemory: true);

    // Tensor creation order is kept: input first, then fc1 and fc2 parameters.
    model.input = model.context.NewTensor1d(GGmlType.GGML_TYPE_F32, 28 * 28);
    model.fc1Weight = model.context.NewTensor(fc1WeightInfo.Type, fc1WeightInfo.Shape.ToArray());
    model.fc1Bias = model.context.NewTensor(fc1BiasInfo.Type, fc1BiasInfo.Shape.ToArray());
    model.fc2Weight = model.context.NewTensor(fc2WeightInfo.Type, fc2WeightInfo.Shape.ToArray());
    model.fc2Bias = model.context.NewTensor(fc2BiasInfo.Type, fc2BiasInfo.Shape.ToArray());

    // Allocate all context tensors on the backend, then upload the parameter bytes.
    model.buffer = model.context.BackendAllocContextTensors(model.backend);
    model.fc1Weight.SetBackend(pickleLoader.ReadByteFromFile(fc1WeightInfo));
    model.fc1Bias.SetBackend(pickleLoader.ReadByteFromFile(fc1BiasInfo));
    model.fc2Weight.SetBackend(pickleLoader.ReadByteFromFile(fc2WeightInfo));
    model.fc2Bias.SetBackend(pickleLoader.ReadByteFromFile(fc2BiasInfo));
    return model;
}

/// <summary>
/// Returns the first tensor called <paramref name="name"/>, or throws when it is absent.
/// </summary>
/// <param name="tensors">Tensor descriptions read from the model file.</param>
/// <param name="name">Exact tensor name to look for.</param>
/// <param name="path">Model file path, used only in the error message.</param>
/// <exception cref="InvalidDataException">No tensor with that name exists.</exception>
private static Tensor FindRequiredTensor(List<Tensor> tensors, string name, string path)
{
    // FindIndex makes "not found" explicit (-1) instead of returning a default value.
    int index = tensors.FindIndex(a => a.Name == name);
    if (index < 0)
    {
        throw new InvalidDataException($"Tensor '{name}' was not found in '{path}'.");
    }
    return tensors[index];
}
/// <summary>
/// Creates a new graph on the model's context and records the forward pass:
/// Linear (fc1) -> Relu -> Linear (fc2) -> SoftMax. The final tensor is named "probs".
/// </summary>
/// <param name="model">Model providing the context and tensors; its <c>graph</c> field is overwritten.</param>
private static void BuildGraph(Model model)
{
    SafeGGmlContext ctx = model.context;
    model.graph = ctx.NewGraph();

    // First layer: the weight is reshaped to 2-D with its first two dimensions swapped.
    SafeGGmlTensor fc1Weight = model.fc1Weight;
    SafeGGmlTensor fc1Reshaped = ctx.Reshape2d(fc1Weight, fc1Weight.Shape[1], fc1Weight.Shape[0]);
    SafeGGmlTensor hidden = ctx.Linear(model.input, fc1Reshaped, model.fc1Bias);
    hidden = ctx.Relu(hidden);

    // Second layer: same reshape pattern applied to the fc2 weight.
    SafeGGmlTensor fc2Weight = model.fc2Weight;
    SafeGGmlTensor fc2Reshaped = ctx.Reshape2d(fc2Weight, fc2Weight.Shape[1], fc2Weight.Shape[0]);
    SafeGGmlTensor logits = ctx.Linear(hidden, fc2Reshaped, model.fc2Bias);

    SafeGGmlTensor probs = ctx.SoftMax(logits);
    probs.Name = "probs";

    model.graph.BuildForwardExpend(probs);
}
// compute with backend
/// <summary>
/// Allocates the graph's tensors with the model's allocator, computes the graph on
/// the model's backend and returns the last node of the graph.
/// </summary>
/// <param name="model">Model whose graph, allocator and backend are already set.</param>
/// <returns>The final graph node, which is the output tensor for this graph.</returns>
private static SafeGGmlTensor Compute(Model model)
{
    SafeGGmlGraph graph = model.graph;

    // Allocate the tensors, then run the computation.
    graph.GraphAllocate(model.allocr);
    graph.BackendCompute(model.backend);

    // The output tensor is the last node in the graph.
    var lastIndex = graph.NodeCount - 1;
    return graph.Nodes[lastIndex];
}
}
}