如何使用张量流中的给定索引切片张量?

时间:2018-11-10 13:17:14

标签: tensorflow slice

我有一个具有概率的张量。这是一个动态张量,形状为(?,30),我选择这30个值中概率最高的索引为:

    best_probability = tf.argmax(probability, axis = 1)

现在张量best_probability的维数是(?,)。现在我想从另一个张量中选择具有这些索引的值,该张量称为具有维度(?,30,1024,3)的数据。本质上,使用best_probability张量从30个值中的每一个中选择一个概率最高。

最终输出的尺寸应为(?,1024,3)。

PS:-我尝试了collect_nd,但是它需要为best_probability张量索引,例如[[0,9],[1,10],[2,15],[3,25]]。为此,我编写了以下代码段。

 selected_data = tf.stack(tf.range(probability.shape[0]),
                                tf.argmax(probability, axis = 1))

这不起作用,因为我正在处理动态张量。是否有解决此问题的替代方法。

1 个答案:

答案 0 :(得分:1)

我能够使用tf.batch_gather和tf.reshape解决此问题

selected_data = tf.reshape(tf.batch_gather(data, best_probability),
                           (-1, data.shape[2],data.shape[3]))