找到protobuf模型的输入张量

时间:2018-04-23 10:28:25

标签: tensorflow

所以事情就是这样:我试图从已被冻结到.pb(ProtoBuf)文件的模型中使用推断。

我已经正确地冻结了模型,选择了我有兴趣用于推理的节点(只是输出)。我也可以选择输出张量但是当我输入张量时它会给我一个错误:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'w2' with dtype float
 [[Node: w2 = Placeholder[dtype=DT_FLOAT, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

这是一个我冻结的简单模型:

import tensorflow as tf

w1 = tf.placeholder("float", name="w1")

w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

print(sess.run(w4, feed_dict))
# Prints 24 which is sum of (w1+w2)*b1

saver.save(sess, 'my_test_model/test', global_step=1000)

以下是我用来进行推理的代码(来自.pb文件):

w1 = tf.placeholder("float")
w2 = tf.placeholder("float")
with tf.Session() as sess:
    init = tf.global_variables_initializer()

    with tf.gfile.FastGFile("my_test_model/frozen_model.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    tensor = sess.graph.get_tensor_by_name('op_to_restore:0')
    # sess.run(init)
    print(tensor)
    predictions = sess.run(tensor, feed_dict={w1: 4, w2: 8})

print(predictions)

任何帮助都会很有价值,谢谢!

1 个答案:

答案 0 :(得分:0)

只需对这个问题做出明确的回答:

如果有人遇到此问题。.对我有用的修复方法是将行feed_dict={w1: 4, w2: 8}更改为feed_dict={'w1:0': 4, 'w2:0': 8},因为已经创建了此节点。如果要打印图形的节点,则得到它们的线是:

[n.name for n in tf.get_default_graph().as_graph_def().node]