使用tensorflow通过索引张量从值张量获取值

时间:2020-09-25 13:06:49

标签: tensorflow

有一个这样的索引张量:[[1,2,3],[1,2,3]](形状是批*长度)

有一个这样的值张量:(形状是批*长度*深)

[[[0.9,0.9,0.1,0.1],[0.9,0.1,0.8,0.1],[0.9,0.1,0.1,0.6]],
[[0.1,0.9,0.8,1],[1,2,0.8,0.1],[0.1,0.1,2,0.6]]]. 

我如何通过张量流获得[[0.9,0.8,0.6],[0.9,0.8,0.6]]

1 个答案:

答案 0 :(得分:2)

我不确定这是否是最好的解决方案,但是它可以工作:tf.gather_nd(values, tf.expand_dims(index, -1), batch_dims=2)

例如:

>>> index = tf.constant([[1,2,3],[1,2,3]])
>>> values = tf.constant([[[0.9,0.9,0.1,0.1],[0.9,0.1,0.8,0.1],[0.9,0.1,0.1,0.6]],[[0.1,0.9,0.8,1],[1,2,0.8,0.1],[0.1,0.1,2,0.6]]])
>>> result = tf.gather_nd(values, tf.expand_dims(index, -1), batch_dims=2)
>>> result.eval()
array([[0.9, 0.8, 0.6],
       [0.9, 0.8, 0.6]], dtype=float32)