参数维度对聚集函数的影响

时间:2017-10-18 02:27:45

标签: torch

我正在尝试使用pytorch中的gather函数,但无法理解dim参数的作用。

代码:

t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))

输出:

 1  2
 3  2
[torch.FloatTensor of size 2x2]

尺寸设为1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))

输出变为:

 1  1
 4  3
[torch.FloatTensor of size 2x2]

如何,gather功能确实有效?

2 个答案:

答案 0 :(得分:4)

我意识到聚集功能是如何工作的。

t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)

由于dimension为零,因此输出为:

| t[index[0, 0], 0]   t[index[0, 1], 1] |
| t[index[1, 0], 0]   t[index[1, 1], 1] |

如果dimension设置为1,则输出将变为:

| t[0, index[0, 0]]   t[0, index[0, 1]] |
| t[1, index[1, 0]]   t[1, index[1, 1]] |

所以公式是:

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

参考:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather

答案 1 :(得分:2)

只需添加到现有答案中,gather的一个应用就是按指定维度收集分数。

例如我们有这样的设置:

  • 3个班级和5个例子
  • 为每个班级分配一个分数,为每个例子执行
  • 目标是收集标签y
  • 指示的分数

代码如下

torch.manual_seed(0)

num_examples = 5
num_classes = 3
scores = torch.randn(5, 3)

#print of scores
scores: tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820],
        [-0.8567,  1.1006, -1.0712]])


y = torch.LongTensor([1, 2, 1, 0, 2])
res = scores.gather(1, y.view(-1, 1)).squeeze()

输出:

#print of gather results
tensor([-0.2934, -1.3986,  0.8380, -0.4033, -1.0712])