pytorch用索引列表修改数组

时间:2018-10-22 04:23:37

标签: python pytorch

假设我有一个索引列表,并希望使用此列表修改现有数组。目前,我唯一可以做到这一点的方法是使用如下所示的for循环。只是想知道是否有更快/有效的方法。

torch.manual_seed(0)
a = torch.randn(5,3)
idx = torch.Tensor([[1,2], [3,2]], dtype=torch.long)
for i,j in idx:
    a[i,j] = 1

我最初认为gatherindex_select可以回答这个问题,但是从documentation看,这似乎不是答案。

在我的特定情况下,a是5维向量,而idx是Nx5向量。因此,我期望的输出(在用a[idx]下标之后)是(N,)形的向量。

答案

由于下面的@shai,我正在寻找的答案是: a[idx.t().chunk(chunks=2,dim=0)]。取自此SO answer

1 个答案:

答案 0 :(得分:1)

这很简单

a[idx[:,0], idx[:,1]] = 1

您可以在this thread中找到更通用的解决方案。