如何使用索引有效地获取张量中每一行的值?

时间:2019-05-31 23:41:54

标签: python tensorflow slice z-index tensor

我有一个名为my_tensor的张量,tha形状为[batch_size, seq_length],而我有另一个名为idx的张量,tha形状为[batch_size, 1],它由从0开始到“ seq_length”结束的索引组成。

我想使用idx中定义的索引提取my_tensor每行中的值。

我尝试使用tf.gather_ndtf.gather,但没有成功。

考虑以下示例:

batch_size = 3
seq_length = 5
idx = [2, 0, 4]

my_tensor = tf.random.uniform(shape=(batch_size, seq_length))

我想在

处获取值
[[0, 2],
 [1, 0],
 [3, 4]]

来自my_tensor。

我必须对它们进行进一步的处理,因此我想同时高效地拥有它们(我什至不知道是否可能);但是,我无法提出其他任何方法。

感谢您的帮助:)

1 个答案:

答案 0 :(得分:0)

诀窍是首先将一组索引转换为布尔掩码,然后可以使用它们boolean_mask来减少my_tensor

您可以通过one-hot encoding idx张量来实现。

因此,在idx = [2, 0, 4]处,我们可以做tf.one_hot(idx, seq_length)以便将其转换为如下形式:

[ [0., 0., 1., 0., 0.],
  [1., 0., 0., 0., 0.],
  [0., 0., 0., 0., 1.] ]

然后,将所有内容放在一起,例如my_tensor

[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
  [0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
  [0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]

我们可以进行以下操作:

result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))

给出:

[0.42499018, 0.8698617 , 0.7595779 ]

符合预期