假设我有一个索引列表,并希望使用此列表修改现有数组。目前,我唯一可以做到这一点的方法是使用如下所示的for循环。只是想知道是否有更快/有效的方法。
torch.manual_seed(0)
a = torch.randn(5,3)
idx = torch.Tensor([[1,2], [3,2]], dtype=torch.long)
for i,j in idx:
a[i,j] = 1
我最初认为gather
或index_select
可以回答这个问题,但是从documentation看,这似乎不是答案。
在我的特定情况下,a是5维向量,而idx是Nx5向量。因此,我期望的输出(在用a[idx]
下标之后)是(N,)
形的向量。
由于下面的@shai,我正在寻找的答案是:
a[idx.t().chunk(chunks=2,dim=0)]
。取自此SO answer。