pytorch等效tf.gather

时间:2018-12-09 23:05:46

标签: tensorflow pytorch

将某些代码从tensorflow移植到pytorch时遇到麻烦。

因此,我有一个尺寸为10x30的矩阵,表示10个示例,每个示例具有30个特征。然后,我有另一个尺寸为10x5的矩阵,其中包含第一个矩阵中每个示例的5个最接近示例的索引。我想使用第二个矩阵中包含的索引“收集”第一个矩阵中每个示例的5个壁橱示例,使我得到一个形状为10x5x30的3d张量。

在张量流中,这是通过tf.gather(matrix1, matrix2)完成的。有谁知道我怎么能在pytorch中做到这一点?

1 个答案:

答案 0 :(得分:1)

怎么样?

matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]

它使用了对整数数组进行索引的技巧。