假设我有一个Tensor of shape(100,20)。现在我还有一个形状指数的张量(100,)。如何获得一个形状张量(100,)或(100,1)每行(100行)正确的值(由指数中的相应索引选择?
小例子: 所以我们说张量A是
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
和张量B是
[0,2,1]
然后我想要输出
[1,6,8]
答案 0 :(得分:4)
您可以使用适当的范围加入B张量来创建二维索引(在您的示例中为[[0, 0], [1, 2], [2, 1]]
),然后使用tf.gather_nd
提取元素:
b_2 = tf.expand_dims(b, 1)
range = tf.expand_dims(tf.range(tf.shape(b)[0]), 1)
ind = tf.concat(1, [range, b_2])
res = tf.gather_nd(a, ind)