我正在使用PyTorch的{{1}}函数,该函数定义为:
argmax
考虑示例
torch.argmax(input, dim=None, keepdim=False)
在这里,当我使用dim = 1而不是搜索列向量时,该函数将搜索行向量,如下所示。
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
就我的假设而言,dim = 0代表行,dim = 1代表列。
答案 0 :(得分:1)
是时候正确理解轴参数如何工作了:
理解上面的图片之后,
| v dim-0 ---> -----> dim-1 ------> -----> --------> dim-1 | [[-1.7739, 0.8073, 0.0472, -0.4084], v [ 0.6378, 0.6575, -1.2970, -0.0625], | [ 1.7970, -1.3463, 0.9011, -0.8704], v [ 1.5639, 0.7123, 0.0385, 1.8410]] | v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])