使用Tensorflow的top_k和scatter_nd

时间:2018-04-22 13:25:33

标签: python tensorflow

我正在尝试在tensorflow中编写一个只传播每个要素图的前k个值的操作。

示例:

k = 1,输入大小为[batch_size, x, y, channels]让我们说它是[1,2,2,3]

输出应该是相同的大小,如果k = 1比每个x,y平面只有一个非零。

numpy中的例子:

input = [[[[6.4 1.4 1.3] [2.1  6.5  4.8]][[2.3 9.2  2.8][7.9  5.1 0.6]]]]]
输出应该是:

[[[[6.4 0. 0.] [0. 6.5 0.]]  [[0. 9.2 0.] [7.9 0. 0.]]]]

为了在tensorflow中执行此操作,我想使用nn.top_k,然后使用scatter_nd。

问题是top_k非常不同地从scatter_nd需要的方式返回所请求元素的索引。

top_k返回形状[[[[0],[1]], [[1],[0]]]]

中的索引数组(1,2,2,1)

scatter_nd需要它作为每个值的所有坐标的列表,如下所示:

[[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1], [0, 1, 1, 0]]

有没有人知道如何在它们之间进行转换?或者甚至可能为这次行动而不同的approch?

1 个答案:

答案 0 :(得分:2)

tf.nn.top_k()仅返回最后一维中的前k个值。所以你必须添加所有其他维度。最简单的tf.where()。代码(已测试):

import tensorflow as tf

inp = tf.constant( [ [ [ [6.4, 1.4, 1.3], [2.1,  6.5,  4.8] ], [ [2.3, 9.2, 2.8], [7.9, 5.1, 10.6] ] ] ] )

t, idx = tf.nn.top_k( inp, k = 2 )
idx_one_hot = tf.one_hot( idx, depth = 3 )
idx_red = tf.reduce_sum( idx_one_hot, axis = -2 )
idx2 = tf.where( tf.not_equal( idx_red, 0 ) )

with tf.Session() as sess:
    print( sess.run( idx2 ) )

输出(注意我已经改变了你的例子中的最后一个数字,索引也是2,只有0和1看起来有点误导,好像它是一个布尔张量):

  

[[0 0 0 0]
   [0 0 0 1]
   [0 0 1 1]
   [0 0 1 2]
   [0 1 0 1]
   [0 1 0 2]
   [0 1 1 0]
   [0 1 1 2]]

请注意,这会丢失top_k报告的最后一个维度中的索引顺序,它会将其更改为索引本身的递增顺序。