根据索引映射将值从一个张量复制到另一个张量,并平均映射到相同索引的值

时间:2020-02-11 17:33:10

标签: python pytorch

我有一个输入张量和一个索引映射。我想基于该索引映射创建输出张量。

例如: 输入= torch.tensor([12,56,45,37],dtype = torch.float)

index_map = torch.tensor([-1,1,2,1,1,3,-1]) (index_map中的-1表示它没有映射到任何东西)

grad = Torch.tensor([4,2,5,5,3,6,7],dtype = torch.float)

我想根据index_map将值从grad复制到输入。映射到相同索引的值应取平均值。例如grad [1]和grad [3]都映射到索引1,因此索引1的最终值应为1.5

输出张量应类似于: 张量([12.,1.5,5.,6。])

我的代码是:

input = torch.tensor([12, 56, 45, 37], dtype=torch.float)
input = torch.cat([input, torch.tensor([-1], dtype=torch.float)])
index_map = torch.tensor([-1, 1, 2, 1, 3, -1])
grad_out = torch.tensor([4, 2, 5, 3, 6, 7], dtype=torch.float)
input[index_map] = grad_out
input = input[:4]
print(input)

以上代码会复制最新值,但不会将其取平均值。

0 个答案:

没有答案