我有一个已保存的Tensorflow图表,它通过placeholder
使用feed_dict
参数来消耗输入。
sess.run(my_tensor, feed_dict={input_image: image})
由于使用Dataset
Iterator
提供数据为more efficient,我想加载已保存的图表,将input_image
placeholder
替换为{{1}并运行。我怎样才能做到这一点?有没有更好的方法呢?代码示例的答案将受到高度赞赏。
答案 0 :(得分:6)
您可以通过序列化图表并使用tf.import_graph_def
重新导入它来实现这一目标,input_map
具有x
参数,用于在所需位置插入输入。
要做到这一点,您至少需要知道您替换的输入的名称以及您希望执行的输出(在我的示例中分别为y
和import tensorflow as tf
# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')
# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
print("with placeholder:")
for i in range(10):
print(sess.run(y, {x: i}))
# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()
tf.reset_default_graph()
# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])
# enjoy Dataset inputs!
with tf.Session() as sess:
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass
。)
graph_def
请注意,占位符节点仍然存在,因为我没有在这里解析GraphDef
来删除它 - 你可以删除它作为改进,虽然我认为将它留在这里也没关系。 / p>
根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需返回.meta
)。例如,如果您从tf.train.import_meta_graph
文件加载图表,则可以使用input_map
接受相同的import tensorflow as tf
# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')
# enjoy Dataset inputs!
with tf.Session() as sess:
# not needed here, but in practice you would also need to restore weights
# restorer.restore(sess, weights_filepath)
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass
参数。
{{1}}