修改火炬标准

时间:2017-06-02 16:58:54

标签: torch loss

我想在Torch中创建一个自定义丢失函数,它是ClassNLLCriterion的修改。具体而言,ClassNLLCriterion损失是:

loss(x, class) = -x[class]

我想将其修改为:

loss(x, class) = -x[class]*K

其中K是网络输入的函数,而不是网络权重或网络输出。因此K可以视为常量。

实施此自定义条件的最简单方法是什么? updateOutput()函数似乎很简单,但如何修改updateGradInput()函数?

1 个答案:

答案 0 :(得分:1)

基本上你的损失函数L是输入和目标的函数。你有

loss(input, target) = ClassNLLCriterion(input, target) * K

如果我理解你的新损失。那么你想实现updateGradInput,它返回你的损失函数相对于输入的导数,这是

updateGradInput[ClassNLLCriterion](input, target) * K + ClassNLLCriterion(input, target) * dK/dinput

因此,您只需计算损失函数输入的K的导数(您没有给我们计算K的公式)并将其插入上一行。由于您的新损失功能依赖于ClassNLLCriterion,因此您可以使用此损失函数的updateGradInputupdateOutput来计算您的损失。