因此,我正在创建一个pytorch模型,对于正向传递,我将应用正向传递方法来获取分数张量,其中包含每个类的预测分数。该张量的形状为[100,10]。现在,我想通过将其与包含实际分数的y进行比较来获得准确性。该张量具有形状[100]。为了比较两者,我将使用torch.mean(scores == y)
,然后计算出多少个相同。
问题是我需要转换分数张量,以便每一行仅包含每一行中最高值的索引。例如,如果张量看起来像这样,
tensor(
[[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
[0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
[0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)
然后我希望将其转换为看起来像这样的形式。
tensor([5, 4, 0])
我该怎么做?
答案 0 :(得分:0)
将argmax
与所需的dim
(又称为轴)一起使用
a = tensor(
[[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
[0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
[0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)
a.argmax(1)
# tensor([ 5, 4, 0])