将某些代码从tensorflow移植到pytorch时遇到麻烦。
因此,我有一个尺寸为10x30的矩阵,表示10个示例,每个示例具有30个特征。然后,我有另一个尺寸为10x5的矩阵,其中包含第一个矩阵中每个示例的5个最接近示例的索引。我想使用第二个矩阵中包含的索引“收集”第一个矩阵中每个示例的5个壁橱示例,使我得到一个形状为10x5x30的3d张量。
在张量流中,这是通过tf.gather(matrix1, matrix2)
完成的。有谁知道我怎么能在pytorch中做到这一点?
答案 0 :(得分:1)
怎么样?
matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]
它使用了对整数数组进行索引的技巧。