有一个这样的索引张量:[[1,2,3],[1,2,3]]
(形状是批*长度)
有一个这样的值张量:(形状是批*长度*深)
[[[0.9,0.9,0.1,0.1],[0.9,0.1,0.8,0.1],[0.9,0.1,0.1,0.6]],
[[0.1,0.9,0.8,1],[1,2,0.8,0.1],[0.1,0.1,2,0.6]]].
我如何通过张量流获得[[0.9,0.8,0.6],[0.9,0.8,0.6]]
?
答案 0 :(得分:2)
我不确定这是否是最好的解决方案,但是它可以工作:tf.gather_nd(values, tf.expand_dims(index, -1), batch_dims=2)
例如:
>>> index = tf.constant([[1,2,3],[1,2,3]])
>>> values = tf.constant([[[0.9,0.9,0.1,0.1],[0.9,0.1,0.8,0.1],[0.9,0.1,0.1,0.6]],[[0.1,0.9,0.8,1],[1,2,0.8,0.1],[0.1,0.1,2,0.6]]])
>>> result = tf.gather_nd(values, tf.expand_dims(index, -1), batch_dims=2)
>>> result.eval()
array([[0.9, 0.8, 0.6],
[0.9, 0.8, 0.6]], dtype=float32)