所以事情就是这样:我试图从已被冻结到.pb(ProtoBuf)文件的模型中使用推断。
我已经正确地冻结了模型,选择了我有兴趣用于推理的节点(只是输出)。我也可以选择输出张量但是当我输入张量时它会给我一个错误:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'w2' with dtype float
[[Node: w2 = Placeholder[dtype=DT_FLOAT, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
这是一个我冻结的简单模型:
import tensorflow as tf
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
print(sess.run(w4, feed_dict))
# Prints 24 which is sum of (w1+w2)*b1
saver.save(sess, 'my_test_model/test', global_step=1000)
以下是我用来进行推理的代码(来自.pb文件):
w1 = tf.placeholder("float")
w2 = tf.placeholder("float")
with tf.Session() as sess:
init = tf.global_variables_initializer()
with tf.gfile.FastGFile("my_test_model/frozen_model.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
tensor = sess.graph.get_tensor_by_name('op_to_restore:0')
# sess.run(init)
print(tensor)
predictions = sess.run(tensor, feed_dict={w1: 4, w2: 8})
print(predictions)
任何帮助都会很有价值,谢谢!
答案 0 :(得分:0)
只需对这个问题做出明确的回答:
如果有人遇到此问题。.对我有用的修复方法是将行feed_dict={w1: 4, w2: 8}
更改为feed_dict={'w1:0': 4, 'w2:0': 8}
,因为已经创建了此节点。如果要打印图形的节点,则得到它们的线是:
[n.name for n in tf.get_default_graph().as_graph_def().node]