我有一个具有概率的张量。这是一个动态张量,形状为(?,30),我选择这30个值中概率最高的索引为:
best_probability = tf.argmax(probability, axis = 1)
现在张量best_probability的维数是(?,)。现在我想从另一个张量中选择具有这些索引的值,该张量称为具有维度(?,30,1024,3)的数据。本质上,使用best_probability张量从30个值中的每一个中选择一个概率最高。
最终输出的尺寸应为(?,1024,3)。
PS:-我尝试了collect_nd,但是它需要为best_probability张量索引,例如[[0,9],[1,10],[2,15],[3,25]]。为此,我编写了以下代码段。
selected_data = tf.stack(tf.range(probability.shape[0]),
tf.argmax(probability, axis = 1))
这不起作用,因为我正在处理动态张量。是否有解决此问题的替代方法。
答案 0 :(得分:1)
我能够使用tf.batch_gather和tf.reshape解决此问题
selected_data = tf.reshape(tf.batch_gather(data, best_probability),
(-1, data.shape[2],data.shape[3]))