沿张量的第二维收集元素

时间:2017-02-11 12:03:37

标签: python tensorflow

假设values和张量T都具有形状(N,K)。现在,如果我们根据矩阵来考虑它们,我希望T的每一行都能获得与values具有最大值的索引相对应的行元素。我可以用

轻松找到这些指数
max_indicies = tf.argmax(T, 1)

返回形状(N)的张量。现在,我如何从T收集这些索引,以便得到形状N?我试过了

result = tf.gather(T,max_indices)

但它没有做正确的事情 - 它会返回一些形状(N,K),这意味着它没有收集任何东西。

1 个答案:

答案 0 :(得分:2)

您可以使用tf.gather_nd

例如,

import tensorflow as tf

sess = tf.InteractiveSession()

values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2 ,  3],
                 [4, 5, 6 ,  7],
                 [8, 9, 10, 11]])

max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))

print(result.eval())

但是,当valuesT的排名较高时,使用tf.gather_nd会有点尴尬。我在this question上发布了我当前的解决方案。对于高维valuesT,可能会有更好的解决方案。