-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathautotune_model.py
156 lines (126 loc) · 5.38 KB
/
autotune_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import os

import numpy as np
import tensorflow as tf
import tvm
import tvm.contrib.graph_runtime as runtime
import tvm.relay.testing.tf as tf_testing
from tvm import autotvm
from tvm import relay
from tvm.autotvm.tuner import XGBTuner, GATuner, GridSearchTuner, RandomTuner
from tvm.contrib.util import tempdir
from tvm.relay import testing

import riptide.models
from riptide.binary.binary_layers import Config, DQuantize, XQuantize
from riptide.get_models import get_model
# Hide all GPUs from CUDA-aware libraries so TensorFlow runs on the CPU only.
# NOTE(review): this is set after `import tensorflow`; it presumably still
# takes effect because TF creates its CUDA context lazily - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ''

# ---------------------------------------------------------------------------
# Command-line options.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--activation_bits', type=int, default=1, help='number of activation bits', required=False)
parser.add_argument('--model', type=str, choices=['vggnet', 'vgg11', 'resnet18', 'alexnet', 'darknet'], help='neural network model', required=True)
parser.add_argument('--trials', type=int, default=50, help='number of tuning trials', required=False)
# Keep these names in sync with the tuner names handled by tune_kernels().
parser.add_argument('--tuner', type=str, default='xgb', choices=['xgb', 'ga', 'random', 'grid', 'gridsearch'], help='autotvm tuning algorithm.', required=False)
parser.add_argument('--log_file', type=str, default='log.log', help='logfile to store tuning results', required=False)
parser.add_argument('--num_threads', type=int, default=12, help='number of threads for the TVM runtime (TVM_NUM_THREADS)', required=False)
args = parser.parse_args()

model_name = args.model
activation_bits = args.activation_bits
trials = args.trials
tuner = args.tuner
log_file = args.log_file

# ---------------------------------------------------------------------------
# Build the requested riptide model inside its quantization Config context.
# ---------------------------------------------------------------------------
config = Config(actQ=DQuantize, weightQ=XQuantize, bits=activation_bits, use_act=False, use_bn=False, use_maxpool=True)
with config:
    model = get_model(model_name)

# Call the model once on a symbolic Keras input so its layer shapes are
# initialized before conversion.
input_shape = [1, 224, 224, 3]
test_input = tf.keras.Input(shape=input_shape[1:], batch_size=input_shape[0], dtype='float32')
output = model(test_input)
print("Test run of model", output)

# Convert the Keras model to a Relay program.
# NOTE(review): 'input_1' is assumed to be the Keras-assigned name of the
# model's input layer - confirm for every model choice.
net, params = relay.frontend.from_keras(model, shape={'input_1': input_shape})

# ---------------------------------------------------------------------------
# Target / runtime settings.
# ---------------------------------------------------------------------------
num_threads = args.num_threads
os.environ["TVM_NUM_THREADS"] = str(num_threads)
target = "llvm"
ctx = tvm.cpu(0)

# Keyword arguments forwarded to tune_kernels().
tuning_option = {
    'log_filename': log_file,
    'tuner': tuner,
    'early_stopping': None,
    'n_trial': trials,
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=150),
    ),
}
def tune_kernels(tasks,
                 measure_option,
                 tuner='xgb',
                 n_trial=100,
                 early_stopping=None,
                 log_filename='tuning.log'):
    """Tune every extracted AutoTVM task and append the records to a log file.

    Each extracted task is re-created with ``autotvm.task.create`` for the
    module-level ``target`` using the ``'direct'`` template, then searched
    with the requested tuner.

    Args:
        tasks: Tasks returned by ``autotvm.task.extract_from_program``.
        measure_option: Value from ``autotvm.measure_option`` that controls
            how candidate configurations are built and timed.
        tuner: Search algorithm. ``'xgb'`` / ``'xgb-rank'`` select
            ``XGBTuner`` with rank loss, ``'ga'`` a genetic-algorithm tuner,
            ``'random'`` a random tuner, and ``'grid'`` / ``'gridsearch'``
            exhaustive grid search. The misspelled ``'xbg-rank'`` is still
            accepted for backward compatibility.
        n_trial: Maximum number of measurement trials per task. Each task is
            additionally capped at the size of its own config space.
        early_stopping: Passed through to ``Tuner.tune``.
        log_filename: File that tuning records are appended to.

    Raises:
        ValueError: If ``tuner`` is not one of the names listed above. This
            is checked before any task is tuned.
    """
    tuner_factories = {
        'xgb': lambda t: XGBTuner(t, loss_type='rank'),
        'xgb-rank': lambda t: XGBTuner(t, loss_type='rank'),
        'xbg-rank': lambda t: XGBTuner(t, loss_type='rank'),  # legacy typo
        'ga': lambda t: GATuner(t, pop_size=50),
        'random': RandomTuner,
        'grid': GridSearchTuner,
        'gridsearch': GridSearchTuner,
    }
    # Fail fast instead of after the first task has been created.
    if tuner not in tuner_factories:
        raise ValueError("Invalid tuner: %s" % (tuner,))
    make_tuner = tuner_factories[tuner]

    for i, tsk in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        func_create = tsk.name
        print(func_create)
        task = autotvm.task.create(func_create, args=tsk.args, target=target,
                                   template_key='direct')
        # Keep the extracted workload so log records match the graph's ops.
        task.workload = tsk.workload

        tuner_obj = make_tuner(task)

        # Cap per task WITHOUT overwriting n_trial: reusing the parameter
        # here would let one small config space shrink the budget of every
        # later task.
        task_trials = min(n_trial, len(task.config_space))
        tuner_obj.tune(n_trial=task_trials,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(task_trials, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])
# Launch jobs and evaluate performance.
def tune_and_evaluate(tuning_opt, number=10, repeat=3):
    """Tune the module-level Relay program, rebuild it with the best configs, and time it.

    Reads the module-level ``net``, ``params``, ``input_shape`` and
    ``target`` without modifying them.

    Args:
        tuning_opt: Keyword arguments forwarded to ``tune_kernels``. Its
            ``'log_filename'`` entry (default ``'tuning.log'``, matching
            ``tune_kernels``) is also the log the best records are read from.
        number: Runs averaged into each timing sample.
        repeat: Number of timing samples. With ``repeat=1`` the reported
            standard deviation is always 0.

    Returns:
        numpy array with one mean run time per sample, in milliseconds.
    """
    print("Extract tasks...")
    tasks = autotvm.task.extract_from_program(
        net,
        target=target,
        params=params,
        ops=(relay.op.nn.conv2d, relay.op.nn.dense, relay.op.nn.bitserial_conv2d,
             relay.op.nn.bitserial_dense))

    # Run tuning tasks.
    print("Tuning...")
    tune_kernels(tasks, **tuning_opt)

    # Compile kernels with the best history records. Read them from the log
    # tune_kernels actually wrote, not from an unrelated global.
    history_log = tuning_opt.get('log_filename', 'tuning.log')
    with autotvm.apply_history_best(history_log):
        print("Compile...")
        with relay.build_config(opt_level=2):
            # Use a new name so the module-level `params` is not clobbered.
            graph, lib, built_params = relay.build_module.build(
                net, target=target, params=params)

    # Upload parameters to device.
    ctx = tvm.cpu()
    data_tvm = tvm.nd.array(np.random.uniform(size=input_shape).astype('float32'))
    module = runtime.create(graph, lib, ctx)
    module.set_input('input_1', data_tvm)
    module.set_input(**built_params)

    # Evaluate.
    print("Evaluate inference time cost...")
    ftimer = module.module.time_evaluator("run", ctx, number=number, repeat=repeat)
    prof_res = np.array(ftimer().results) * 1000  # seconds -> milliseconds
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
          (np.mean(prof_res), np.std(prof_res)))
    return prof_res
# Only start the (long-running) tuning job when executed as a script, not on import.
if __name__ == "__main__":
    tune_and_evaluate(tuning_option)