无法使用simple_save张量流导出保存的模型

时间:2019-04-17 06:19:59

标签: tensorflow

我正在尝试对tensorflow使用simple_save,但是它不起作用:(

这是我的代码:

def export_model(saved_model_dir, final_tensor_name):
 with tf.Session() as sess:
 with sess.graph.as_default() as graph:
 tf.saved_model.simple_save(
  sess,
  saved_model_dir,
  inputs={'image': tf.placeholder(tf.float32)},
  outputs={'prediction': graph.get_tensor_by_name(final_tensor_name + ":0")}
 )

我收到以下错误:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value final_training_ops/biases/final_biases
 [[{{node save/SaveV2}}]]

我正在使用以下教程:https://github.com/BartyzalRadek/Multi-label-Inception-net

我花了很多时间试图在线查找解决方案,我知道这并不难。我已经有一个要导出的图形,现在我需要的只是savedmodel.pb。任何帮助表示赞赏!谢谢!

新更新-下面的代码

def export_model(saved_model_dir, final_tensor_name):
  with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    with sess.graph.as_default() as graph:
      tf.saved_model.simple_save(
        sess,
        saved_model_dir,
        inputs={'image': tf.placeholder(tf.string)},
        outputs={'prediction': graph.get_tensor_by_name(final_tensor_name + ":0")}
       )

代码现在可以运行,但是当我测试保存的模型时,我总是得到相同的结果。

IMAGE_LABELING_CODE

import tensorflow as tf
import sys

image_path = sys.argv[1]

image_data = tf.gfile.FastGFile(image_path, 'rb').read()

label_lines = [line.rstrip() for line 
               in tf.gfile.GFile("labels.txt")]

with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

    predictions = sess.run(softmax_tensor, \
         {'DecodeJpeg/contents:0': image_data})

1 个答案:

答案 0 :(得分:0)

就像@giser_yugang所说的,也许您应该在图表的构建部分的末尾:init = tf.global_variables_initializer(),然后在执行会话之后,在执行sess.run(init)

尽管如此,如果它是一个局部变量,则必须将变量添加到某个集合中,建立初始化程序并运行它。例如:

a = tf.Variable(..., collections=[tf.GRAPH_KEYS.LOCAL_VARIABLES])
local_init = tf.local_variable_initializer()
...

with tf.Session() as sess:
    sess.run(local_init)

尽管如此,tensorflow库中的某些实现直接进入局部变量,例如tf.metrics(如果尚未更改),则只需定义并运行local_init = tf.local_variables_initializer()和{{1} }