如何在TensorFlow的预训练模型中获得maxpool图层的输出?

时间:2017-06-07 13:16:39

标签: tensorflow deep-learning

我有一个训练有素的模特。我希望从模型中提取中间maxpool图层的输出。 我尝试了以下

saver = tf.train.import_meta_graph(BASE_DIR + LOG_DIR + '/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint(BASE_DIR + LOG_DIR))
sess.run("maxpool/maxpool",feed_dict=feed_dict)

这里,feed_dict包含占位符及其在字典中运行的内容。

我一直收到以下错误

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1_1' with dtype float and shape...

这可能是什么原因?我生成了所有占位符并在feed字典中输入它们。

1 个答案:

答案 0 :(得分:1)

我遇到了类似的问题而且令人沮丧。让我解决的是为我稍后想要调用的每个变量和操作填写name字段。您还可能需要将maxpool/maxpool操作添加到tf.add_to_collection('name_for_maxpool_op', maxpool_op_handle)的集合中。然后,您可以使用以下命令恢复操作和命名张量:

# Restore from metagraph.
saver = tf.train.import_meta_graph(...)
sess = tf.Session()
saver = restore(sess, ...)
graph = sess.graph

# Restore your ops and tensors.
maxpool_op = tf.get_collection('name_for_maxpool_op')[0]  # returns a list, you want the first element
a_tensor = graph.get_tensor_by_name('tensor_name:0')  # need the :0 added to your name

然后,您将使用恢复的张量构建feed_dictMore information can be found here。另外,正如您在评论中提到的,您需要将操作本身传递给sess.run,而不是它的名称:

sess.run(maxpool_op, feed_dict=feed_dict)

即使你没有命名它们,你也可以从恢复的元图中访问你的张量和操作(例如,为了避免用新的花式张量名称重新训练模型),但这可能有点痛苦。自动赋予张量的名称并不总是最透明的。您可以使用以下命令列出图表中所有变量的名称:

print([v.name for v in tf.all_variables()])

您可以希望找到您要在那里找到的名称,然后使用graph.get_tensor_by_name恢复该张量,如上所述。