我有一个输入张量和一个索引映射。我想基于该索引映射创建输出张量。
例如: 输入= torch.tensor([12,56,45,37],dtype = torch.float)
index_map = torch.tensor([-1,1,2,1,1,3,-1]) (index_map中的-1表示它没有映射到任何东西)
grad = Torch.tensor([4,2,5,5,3,6,7],dtype = torch.float)
我想根据index_map将值从grad复制到输入。映射到相同索引的值应取平均值。例如grad [1]和grad [3]都映射到索引1,因此索引1的最终值应为1.5
输出张量应类似于: 张量([12.,1.5,5.,6。])
我的代码是:
input = torch.tensor([12, 56, 45, 37], dtype=torch.float)
input = torch.cat([input, torch.tensor([-1], dtype=torch.float)])
index_map = torch.tensor([-1, 1, 2, 1, 3, -1])
grad_out = torch.tensor([4, 2, 5, 3, 6, 7], dtype=torch.float)
input[index_map] = grad_out
input = input[:4]
print(input)
以上代码会复制最新值,但不会将其取平均值。