如何找到张量对象中每一行的最大索引?

时间:2019-02-14 01:08:12

标签: python python-3.x tensorflow pytorch

因此,我正在创建一个pytorch模型,对于正向传递,我将应用正向传递方法来获取分数张量,其中包含每个类的预测分数。该张量的形状为[100,10]。现在,我想通过将其与包含实际分数的y进行比较来获得准确性。该张量具有形状[100]。为了比较两者,我将使用torch.mean(scores == y),然后计算出多少个相同。

问题是我需要转换分数张量,以便每一行仅包含每一行中最高值的索引。例如,如果张量看起来像这样,

tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],

    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],

    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

然后我希望将其转换为看起来像这样的形式。

tensor([5, 4, 0])

我该怎么做?

1 个答案:

答案 0 :(得分:0)

argmax与所需的dim(又称为轴)一起使用

a = tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

a.argmax(1)
# tensor([ 5,  4,  0])