在Tensorflow中,如何解析由tf.nn.max_pool_with_argmax获得的扁平指数?

时间:2016-06-07 14:32:08

标签: python tensorflow deep-learning

我遇到了一个问题:在我使用tf.nn.max_pool_with_argmax之后,我获得了索引,即 argmax: A Tensor of type Targmax. 4-D. The flattened indices of the max values chosen for each output.

如何将展平的索引解开回Tensorflow中的坐标列表?

非常感谢。

1 个答案:

答案 0 :(得分:2)

我今天遇到了同样的问题,我最终得到了这个解决方案:

def unravel_argmax(argmax, shape):
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.pack(output_list)

这是ipython笔记本中的usage example(我使用它将汇集的argmax位置转发到我的解放方法)