How do I load and use a saved model in TensorFlow?

Asked: 2017-08-16 04:21:25

Tags: tensorflow

I have found two ways of saving a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on how to use the model after loading it the second way.

Note: I want to use the SavedModelBuilder way, because I train the model in Python and will serve it in another language (Go), and it seems that SavedModelBuilder is the only way to do that.

This works with tf.train.Saver() (the first way):

model = tf.add(W * x, b, name="finalnode")

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.

model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})

tf.saved_model.builder.SavedModelBuilder() is defined in the Readme, but after loading the model with tf.saved_model.loader.load(sess, [], export_dir), I can't find documentation on getting the nodes back (see "finalnode" in the code above).

4 Answers:

Answer 0 (score: 9)

What is missing is the signature:

# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= {
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs= {"x": x},
            outputs= {"finalnode": model})
        })
builder.save()

# loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))

Answer 1 (score: 3)

TensorFlow's preferred way of building and using a model in different languages is tensorflow serving.

In your case, you are using saver.save to save the model. This writes meta files, ckpt files, and a few others that hold the weights, the network definition, the training step, and so on. It is the preferred way to save while you are still training.

Once training is finished, you should freeze the graph from the files saver.save produced, using SavedModelBuilder. The frozen graph is a pb file containing the whole network together with its weights.

This frozen model should then be served with tensorflow serving, and other languages can consume it using the gRPC protocol.

The whole process is described in this excellent tutorial.
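As a minimal sketch of the freezing step described above: one common way to freeze is tf.graph_util.convert_variables_to_constants (the freeze_graph tool wraps it). The toy model, paths, and node names below are illustrative; tf.compat.v1 is used so the snippet also runs under TF 2.x (in a TF 1.x environment, plain `import tensorflow as tf` works the same).

```python
import tensorflow.compat.v1 as tf  # in TF 1.x: import tensorflow as tf

tf.disable_eager_execution()

# Toy model matching the question: finalnode = W * x + b
x = tf.placeholder(tf.float32, shape=[None], name="x")
W = tf.Variable(2.0)
b = tf.Variable(1.0)
model = tf.add(W * x, b, name="finalnode")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fold the current variable values into the graph as constants,
    # producing a self-contained GraphDef ("frozen graph").
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["finalnode"])

# Serialize the frozen graph to a pb file.
with tf.gfile.GFile("/tmp/frozen_model.pb", "wb") as f:
    f.write(frozen.SerializeToString())
```

The resulting pb file no longer needs checkpoint files, which is what makes it convenient to ship to a serving environment.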

Answer 2 (score: 1)

Below are code snippets that use simple_save to save a model, then restore it for prediction:

#Save the model:
tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                                   inputs={"inputImageBatch": X_train, "inputClassBatch": Y_train,
                                           "isTrainingBool": isTraining},
                                   outputs={"predictedClassBatch": predClass})

Note that using simple_save sets certain defaults (these can be seen at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py).
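For reference, here is a sketch of roughly what those defaults amount to: a SavedModelBuilder call with the SERVING tag and the default serving signature key. The toy graph and directory are illustrative, and tf.compat.v1 is used so it also runs under TF 2.x.

```python
import tempfile
import tensorflow.compat.v1 as tf  # in TF 1.x: import tensorflow as tf

tf.disable_eager_execution()

export_dir = tempfile.mkdtemp() + "/model"  # must not exist yet

# Toy graph with one variable so there is something to save.
x = tf.placeholder(tf.float32, shape=[None], name="x")
W = tf.Variable(3.0)
y = tf.identity(x * W, name="y")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={"x": x}, outputs={"y": y})
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],       # default tag
        signature_def_map={
            # default signature key: "serving_default"
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature})
    builder.save()
```

Writing it out explicitly like this makes it clear why the SERVING tag and the default signature key must be used when loading a model saved with simple_save.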

Now, to restore it and use the input/output dictionaries:

from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

with tf.Session() as sess:
  model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING]) # Note the SERVING tag is applied by default.

  inputImage_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['inputImageBatch'].name
  inputImage = tf.get_default_graph().get_tensor_by_name(inputImage_name)

  inputLabel_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['inputClassBatch'].name
  inputLabel = tf.get_default_graph().get_tensor_by_name(inputLabel_name)

  isTraining_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['isTrainingBool'].name
  isTraining = tf.get_default_graph().get_tensor_by_name(isTraining_name)

  outputPrediction_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs['predictedClassBatch'].name
  outputPrediction = tf.get_default_graph().get_tensor_by_name(outputPrediction_name)

  outPred = sess.run(outputPrediction, feed_dict={inputImage:sampleImages, isTraining:False})

  print("predicted classes:", outPred)

Note: the default signature_def is needed in order to use the tensor names specified in the input and output dicts.
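The pattern above can be condensed into a self-contained sketch: save a toy model with simple_save, then read the tensor names out of the default signature_def instead of hard-coding them. The model, names, and values here are illustrative; tf.compat.v1 is used so it also runs under TF 2.x.

```python
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf  # in TF 1.x: import tensorflow as tf
from tensorflow.python.saved_model import tag_constants, signature_constants

tf.disable_eager_execution()

export_dir = tempfile.mkdtemp() + "/model"

# Save a toy model (y = 2 * x) with simple_save.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None], name="x")
    w = tf.Variable(2.0)
    y = tf.identity(x * w, name="y")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.saved_model.simple_save(sess, export_dir,
                                   inputs={"x": x}, outputs={"y": y})

# Load it back and discover the tensor names from the signature.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(
        sess, [tag_constants.SERVING], export_dir)
    sig = meta_graph.signature_def[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    x_t = sess.graph.get_tensor_by_name(sig.inputs["x"].name)   # e.g. "x:0"
    y_t = sess.graph.get_tensor_by_name(sig.outputs["y"].name)
    result = sess.run(y_t, {x_t: np.array([1.0, 2.0])})
```

Reading names from the signature_def keeps the serving code decoupled from the graph's internal node names.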

Answer 3 (score: 0)

A code snippet that worked for me: loading a pb file and running inference on a single image.

The code follows these steps: load the pb file into a GraphDef (a serialized version of a graph, used to read pb files), load the GraphDef into a Graph, get the input and output tensors by name, and run inference on a single image.

import tensorflow as tf 
import numpy as np
import cv2

INPUT_TENSOR_NAME = 'input_tensor_name:0'
OUTPUT_TENSOR_NAME = 'output_tensor_name:0'

# Read image, get shape
# Add dimension to fit batch shape
img = cv2.imread(IMAGE_PATH)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = img.astype(float)
height, width, channels = image.shape
image = np.expand_dims(image, 0)  # Add dimension (to fit batch shape)


# Read pb file into the graph as GraphDef - serialized version of a graph (used to read pb files)
with tf.gfile.FastGFile(PB_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Load GraphDef into Graph
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# Get tensors (input and output) by name
input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# Inference on single image
with tf.Session(graph=graph) as sess:
    output_vals = sess.run(output_tensor, feed_dict={input_tensor: image})