沿行的火炬中的散射张量

时间:2020-06-12 18:34:18

标签: python tensorflow pytorch tensor scatter

我想将张量散布在行的粒度中。

例如,

Input = torch.tensor([[2, 3], [3, 4], [4, 5]])

我想分散

S = torch.tensor([[1,2],[1,2]])

索引

I = torch.tensor([0,2])

我希望输出为torch.tensor([[1, 2], [3, 4], [1, 2]])

此处S[0]被分散到输入[I[0]],类似地S[1]被分散到Input[I[1]]

我该如何实现?我不是在遍历S中的行,而是在寻找一种更有效的方法。

2 个答案:

答案 0 :(得分:1)

执行input[I] = S

示例:

input = torch.tensor([[2, 3], [3, 4], [4, 5]])
S = torch.tensor([[1,2],[1,2]])
I = torch.tensor([0,2])
input[I] = S
input
tensor([[1, 2],
        [3, 4],
        [1, 2]])

答案 1 :(得分:0)

回答可能有点晚,但无论如何,你可以这样做:

import torch

inp = torch.tensor([[2,3], [3,4], [4,5]])
src = torch.tensor([[1,2], [1,2]])
idxs = torch.tensor([[0,0],[2,2]])

y = torch.scatter(inp, 0, idxs, src)
相关问题