基于索引向量的张量约简

时间:2019-04-25 17:30:39

标签: python machine-learning pytorch

作为一个例子,我有2个张量:A = [1;2;3;4;5;6;7]B = [2;3;2]。我的想法是我想基于B来减少A-这样B的值代表如何求和A的值-这样B = [2;3;2]意味着减少的A应该是前两个值,下一个3和最后一个值的总和2:A' = [(1+2);(3+4+5);(6+7)]。显然,B的总和应始终等于A的长度。我试图尽可能有效地做到这一点-最好是pytorch / python中包含的特定函数或矩阵运算。谢谢!

1 个答案:

答案 0 :(得分:0)

这是解决方案。

  • 首先,我们创建一个索引数组B_idx,其大小为A
  • 然后,使用A根据索引B_idx累积(添加)index_add_中的所有元素。
A = torch.arange(1, 8) 
B = torch.tensor([2, 3, 2])

B_idx = [idx.repeat(times) for idx, times in zip(torch.arange(len(B)), B)]
B_idx = torch.cat(B_idx) # tensor([0, 0, 1, 1, 1, 2, 2])

A_sum = torch.zeros_like(B)
A_sum.index_add_(dim=0, index=B_idx, source=A)
print(A_sum) # tensor([ 3, 12, 13])