我有下面这段代码可以预热模型。
运行后我得到了 `saved_model.pb`，但 `variables` 文件夹是空的。
如何才能把 `variables` 也保存下来？
import tensorflow as tf
import os
# Export the latest training checkpoint found in SAVE_PATH as a SavedModel
# (saved_model.pb + variables/) under SERVE_PATH, with a single PREDICT
# signature mapping 'inputs' -> 'ImageTensor:0' and
# 'outputs' -> 'SemanticPredictions:0'.
SAVE_PATH = 'C:/work/'
MODEL_NAME = 'test'
VERSION = 1
SERVE_PATH = 'C:/work/out/{}/{}'.format(MODEL_NAME, VERSION)

checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
print(checkpoint)
if checkpoint is None:
    # latest_checkpoint returns None when SAVE_PATH holds no checkpoint;
    # fail clearly instead of crashing on `None + '.meta'` below.
    raise FileNotFoundError('No checkpoint found in {!r}'.format(SAVE_PATH))

tf.reset_default_graph()
with tf.Session() as sess:
    # Rebuild the graph structure from the checkpoint's .meta file.
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    if saver is None:
        # import_meta_graph returns None when the graph contains no
        # variables (e.g. a frozen graph whose weights are constants), so
        # the exported variables/ directory will necessarily be empty.
        print('Graph in {!r} has no variables; variables/ will be empty.'
              .format(checkpoint))
    else:
        # Load the trained weights from the checkpoint. Running
        # tf.global_variables_initializer() here instead would overwrite
        # every variable with its initial value and export untrained weights.
        saver.restore(sess, checkpoint)

    # get the graph for this session
    graph = tf.get_default_graph()

    # get the tensors that we need
    inputs = graph.get_tensor_by_name('ImageTensor:0')
    predictions = graph.get_tensor_by_name('SemanticPredictions:0')
    print(inputs, predictions)

    # create tensors info
    model_input = tf.saved_model.utils.build_tensor_info(inputs)
    model_output = tf.saved_model.utils.build_tensor_info(predictions)

    # build signature definition
    signature_definition = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'inputs': model_input},
        outputs={'outputs': model_output},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    # NOTE(review): SavedModelBuilder is expected to refuse an export dir
    # that already exists -- confirm, and remove SERVE_PATH before re-running.
    builder = tf.saved_model.builder.SavedModelBuilder(SERVE_PATH)
    print(tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    # Must be called while `sess` still holds the restored variable values:
    # this is what writes them into SERVE_PATH/variables/.
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_definition,
        })

    # Save the model so we can serve it with a model server :)
    builder.save()