为什么dim = 1返回torch.argmax中的行索引?

时间:2019-04-15 14:50:47

标签: python matrix pytorch tensor argmax

我正在使用PyTorch的{​​{1}}函数,该函数定义为:

argmax

考虑示例

torch.argmax(input, dim=None, keepdim=False)

在这里,当我使用dim = 1而不是搜索列向量时,该函数将搜索行向量,如下所示。

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

就我的假设而言,dim = 0代表行,dim = 1代表列。

1 个答案:

答案 0 :(得分:1)

是时候正确理解轴参数如何工作了:

tensor dimension

理解上面的图片之后,

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])