-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmnist_dist_export.py
91 lines (76 loc) · 3.38 KB
/
mnist_dist_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Export a trained MNIST TensorFlow model as a SavedModel.

The model is a pretrained MNIST classifier saved as a TensorFlow checkpoint.
This program restores it and uses the TensorFlow SavedModel API to export it
with a prediction signature that the standard tensorflow_model_server can load.

Usage: mnist_dist_export.py [--model_version=y] [--checkpoint_path=checkpoint_oss_path] export_dir
"""
import os
import sys
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
from tensorflow.examples.tutorials.mnist import input_data as mnist_input_data
# Command-line flags read by main(); tf.app.run() parses them from argv.
_app_flags = tf.app.flags
_app_flags.DEFINE_integer('model_version', 1, 'version number of the exported model.')
_app_flags.DEFINE_string('checkpoint_path', None, 'Checkpoints path.')
FLAGS = _app_flags.FLAGS
def _fail(message):
    """Write *message* to stderr and terminate with exit status -1."""
    sys.stderr.write(message + '\n')
    sys.exit(-1)


def _resolve_checkpoint_prefix(checkpoint_dir):
    """Return the checkpoint prefix to restore from *checkpoint_dir*.

    ``<checkpoint_dir>/model.ckpt-0`` (the path this script has always used) is
    preferred whenever its ``.meta`` file exists. Otherwise the prefix reported
    by ``tf.train.latest_checkpoint`` is used if its ``.meta`` file exists.
    If neither is available, the ``model.ckpt-0`` prefix is returned so the
    caller can report which file is missing.
    """
    default_prefix = os.path.join(checkpoint_dir, 'model.ckpt-0')
    if tf.gfile.Exists(default_prefix + '.meta'):
        return default_prefix
    latest_prefix = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_prefix and tf.gfile.Exists(latest_prefix + '.meta'):
        return latest_prefix
    return default_prefix


def main(_):
    """Restore the MNIST checkpoint and export it as a versioned SavedModel.

    Reads ``--model_version`` and ``--checkpoint_path`` from FLAGS. The export
    base directory is the last command-line argument. The SavedModel is written
    to ``<export_dir>/<model_version>`` and tagged for serving. It has a single
    ``predict_images`` signature that maps input ``images`` (tensor
    ``input/x-input:0``) to output ``scores`` (tensor ``softmax_layer/y:0``).

    Exits with status -1 on invalid arguments or when no checkpoint meta graph
    can be found.
    """
    # Single-argument print(...) calls behave the same on Python 2 and 3;
    # the former `print '...'` statements were a SyntaxError on Python 3.
    if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
        _fail('Usage: mnist_dist_export.py '
              '[--model_version=y] [--checkpoint_path=checkpoint_store_path] export_dir')
    if FLAGS.model_version <= 0:
        _fail('Please specify a positive value for exported serveable version number.')
    if not FLAGS.checkpoint_path:
        _fail('Please specify the correct path where checkpoints stored locally or in OSS.')

    ckpt_path = _resolve_checkpoint_prefix(FLAGS.checkpoint_path)
    meta_graph_file = ckpt_path + '.meta'
    if not tf.gfile.Exists(meta_graph_file):
        _fail('Meta graph file not found: %s' % meta_graph_file)

    with tf.Session() as new_sess:
        # clear_devices=True drops the device placements stored in the meta
        # graph (presumably from the distributed training job), so the graph
        # can be restored on this host.
        new_saver = tf.train.import_meta_graph(meta_graph_file, clear_devices=True)
        new_saver.restore(new_sess, ckpt_path)
        new_graph = new_sess.graph

        new_x = new_graph.get_tensor_by_name('input/x-input:0')
        print(new_x)
        new_y = new_graph.get_tensor_by_name('softmax_layer/y:0')
        print(new_y)

        # Serving expects <export_base>/<version>.
        export_path_base = sys.argv[-1]
        export_path = os.path.join(
            compat.as_bytes(export_path_base),
            compat.as_bytes(str(FLAGS.model_version)))
        print('Exporting trained model to %s' % compat.as_text(export_path))
        builder = saved_model_builder.SavedModelBuilder(export_path)

        # Build the signature_def_map.
        tensor_info_x = utils.build_tensor_info(new_x)
        tensor_info_y = utils.build_tensor_info(new_y)
        prediction_signature = signature_def_utils.build_signature_def(
            inputs={'images': tensor_info_x},
            outputs={'scores': tensor_info_y},
            method_name=signature_constants.PREDICT_METHOD_NAME)

        # tf.tables_initializer() is the non-deprecated replacement for
        # tf.initialize_all_tables().
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            new_sess, [tag_constants.SERVING],
            signature_def_map={
                'predict_images': prediction_signature,
            },
            legacy_init_op=legacy_init_op,
            clear_devices=True)
        builder.save()
        print('Done exporting!')
if __name__ == '__main__':
    # Parse command-line flags, then hand control to main().
    tf.app.run(main=main)