更新PyTorch中的特定矢量元素

时间:2018-11-15 14:39:08

标签: python numpy pytorch autograd

我有一个较大的向量要更新。我将通过向向量中的特定元素添加偏移量来更新它。我指定了要更新的索引向量(称为索引向量ix),并为每个索引指定了要添加到该元素的值(称为值向量vals )。如果索引向量的所有条目都是唯一的,那么下面的代码就足够了:

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,2], dtype=torch.long)
vals = torch.tensor([0.2, 0.5], dtype=torch.float)
vec[ix] += vals

但是,如果ix中有重复的索引,则此方法不起作用。对于重复索引的情况,幼稚的方法如下:

for i in range(len(ix)):
    vec[ix[i]] += vals[i]

但这不能很好地扩展-ix很大时,它的速度很慢。有没有更快的方法可以做到这一点?如果有一种快速的方法来汇总vals中具有相同索引的ix的所有条目,那么解决方案应该很简单。

更新
我找到了一种效果很好的解决方案,如下所述。我仍然很乐意提供更好的解决方案。

# get unique indices
ix_unique = torch.unique(ix)

# for each unique index, get sum of all vals with that index
vals_unique = torch.stack([
    torch.sum(torch.where(ix==i, vals, torch.zeros_like(vals))) 
    for i in ix_unique
])

# update vec
vec[ix_unique] += vals_unique

2 个答案:

答案 0 :(得分:0)

对于要允许对同一ix索引进行多次更新的情况,还存在一个名为pytorch_scatter的库。 在这种情况下, ix中的3个相同的索引将导致将3 * val添加到该索引。

答案 1 :(得分:0)

torch.index_add()

import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,0,2], dtype=torch.long)
vals = torch.tensor([0.2,0.1,0.5], dtype=torch.float)
torch.index_add(vec, 0, ix, vals)

您将得到

tensor([0.3000, 0.0000, 0.5000, 0.0000])

参考:official doc