PyTorch 在具有退化索引的索引处添加张量

时间:2021-06-02 22:12:27

标签: python pytorch

这个问题可以看作是对this one的扩展。

我有两个一维张量,countsidx。 Counts 的长度为 20,并存储落入 20 个 bin 中的 1 个的事件的发生次数。 idx 很长,每一项都是一个整数,对应 20 个事件中的 1 个发生,每个事件可以发生多次。我想要一种矢量化或非常快速的方法来将 i 中事件 idx 发生的次数添加到 i 中的第 counts 个存储桶。此外,如果该解决方案与训练循环期间对 countidx 的批处理兼容,那将是理想的。

我的第一个想法是简单地使用这种索引 countsidx 的策略:

counts = torch.zeros(5)

idx = torch.tensor([1,1,1,2,3])

counts[idx] += 1

但它不起作用,计数结束于

tensor([0., 1., 1., 1., 0.])

而不是想要的

tensor([0., 3., 1., 1., 0.])

我能做到这一点的最快方法是什么?我的下一个最佳猜测是

for i in range(20):
   counts[i] += idx[idx == i].sum()

1 个答案:

答案 0 :(得分:0)

请考虑使用 bincount 函数实现的以下提案,该函数计算 非负 整数张量中每个值的频率(唯一约束)。

import torch

EVENT_TYPES = 20
counts = torch.zeros(EVENT_TYPES)
events = torch.tensor([1, 1, 1, 2, 3, 9])
batch_counts = torch.bincount(events, minlength=EVENT_TYPES)

print(counts + batch_counts)

结果:

tensor([0., 3., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0.])

您可以仅在火炬张量环境中对每个批次进行评估。您可以使用 bincount 函数中的 minlength 参数控制事件类型的数量。在本例中为 20,如您在问题中所述。