我有一个这样的张量:
x = torch.tensor([[3, 4, 2], [0, 1, 5]])
而且我有这样的索引:
ind = torch.tensor([[1, 1, 0], [0, 0, 1]])
然后我想通过x
和ind
生成一个新的张量:
z = torch.tensor([0, 1, 2], [3, 4, 5])
我用python这样实现:
# -*- coding: utf-8 -*-
import torch
x = torch.tensor([[3, 4, 2], [0, 1, 5]])
ind = torch.tensor([[1, 1, 0], [0, 0, 1]])
z = torch.zeros_like(x)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
z[i, j] = x[ind[i][j]][j]
print(z)
我想知道如何通过pytorch
解决此问题吗?
答案 0 :(得分:1)
您正在寻找torch.gather
In [1]: import torch
In [2]: x = torch.tensor([[3, 4, 2], [0, 1, 5]])
In [3]: ind = torch.tensor([[1, 1, 0], [0, 0, 1]])
In [4]: torch.gather(x, 0, ind)
Out[4]:
tensor([[0, 1, 2],
[3, 4, 5]])