如何替换已保存图表的输入,例如数据集迭代器的占位符?

时间:2018-05-16 07:04:35

标签: python tensorflow tensorflow-datasets

我有一个已保存的Tensorflow图表,它通过placeholder使用feed_dict参数来消耗输入。

sess.run(my_tensor, feed_dict={input_image: image})

由于使用Dataset Iterator提供数据为more efficient,我想加载已保存的图表,将input_image placeholder替换为{{1}并运行。我怎样才能做到这一点?有没有更好的方法呢?代码示例的答案将受到高度赞赏。

1 个答案:

答案 0 :(得分:6)

您可以通过序列化图表并使用tf.import_graph_def重新导入它来实现这一目标,input_map具有x参数,用于在所需位置插入输入。

要做到这一点,您至少需要知道您替换的输入的名称以及您希望执行的输出(在我的示例中分别为yimport tensorflow as tf # restore graph (built from scratch here for the example) x = tf.placeholder(tf.int64, shape=(), name='x') y = tf.square(x, name='y') # just for display -- you don't need to create a Session for serialization with tf.Session() as sess: print("with placeholder:") for i in range(10): print(sess.run(y, {x: i})) # serialize the graph graph_def = tf.get_default_graph().as_graph_def() tf.reset_default_graph() # build new pipeline batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next() # plug in new pipeline [y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0']) # enjoy Dataset inputs! with tf.Session() as sess: print('with Dataset:') try: while True: print(sess.run(y)) except tf.errors.OutOfRangeError: pass 。)

graph_def

请注意,占位符节点仍然存在,因为我没有在这里解析GraphDef来删除它 - 你可以删除它作为改进,虽然我认为将它留在这里也没关系。 / p>

根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需返回.meta)。例如,如果您从tf.train.import_meta_graph文件加载图表,则可以使用input_map接受相同的import tensorflow as tf # build new pipeline batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next() # load your net and plug in new pipeline # you need to know the name of the tensor where to plug-in your input restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch}) y = tf.get_default_graph().get_tensor_by_name('y:0') # enjoy Dataset inputs! with tf.Session() as sess: # not needed here, but in practice you would also need to restore weights # restorer.restore(sess, weights_filepath) print('with Dataset:') try: while True: print(sess.run(y)) except tf.errors.OutOfRangeError: pass 参数。

{{1}}