MSE损失功能在pytorch中如何工作?

时间:2018-12-04 13:49:56

标签: python pytorch loss-function mse

在以下代码中(从SentEval中提取),定义了一个神经网络结构,该结构将1024个实数映射到5个输出预测。问题在于评估两个句子之间的相关性(每个句子都具有512个特征)。相关性是[1,5]中的数字。我认为,如果训练相关性数在{1,2,3,4,5}中,cross entropy是一个更好的损失函数,但是由于在训练集中我们在[1,5]中具有真实的相关性数,MSE用作损失函数。

问题:由于网络为每个输入输出5个概率数,如何在实数和5个概率数之间计算MSE

from torch import nn

inputdim = 1024
nclasses = 5
model = nn.Sequential(
            nn.Linear(inputdim, nclasses),
            nn.Softmax(dim=-1),
        )
loss_fn = nn.MSELoss()

0 个答案:

没有答案