我有一个名为my_tensor的张量,tha形状为[batch_size, seq_length]
,而我有另一个名为idx的张量,tha形状为[batch_size, 1]
,它由从0开始到“ seq_length”结束的索引组成。
我想使用idx中定义的索引提取my_tensor每行中的值。
我尝试使用tf.gather_nd
和tf.gather
,但没有成功。
考虑以下示例:
batch_size = 3
seq_length = 5
idx = [2, 0, 4]
my_tensor = tf.random.uniform(shape=(batch_size, seq_length))
我想在
处获取值[[0, 2],
[1, 0],
[3, 4]]
来自my_tensor。
我必须对它们进行进一步的处理,因此我想同时高效地拥有它们(我什至不知道是否可能);但是,我无法提出其他任何方法。
感谢您的帮助:)
答案 0 :(得分:0)
诀窍是首先将一组索引转换为布尔掩码,然后可以使用它们boolean_mask来减少my_tensor
。
您可以通过one-hot encoding idx
张量来实现。
因此,在idx = [2, 0, 4]
处,我们可以做tf.one_hot(idx, seq_length)
以便将其转换为如下形式:
[ [0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.] ]
然后,将所有内容放在一起,例如my_tensor
:
[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
[0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
[0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]
我们可以进行以下操作:
result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))
给出:
[0.42499018, 0.8698617 , 0.7595779 ]
符合预期