我目前正在使用Torch 7,我需要自定义丢失功能,尤其是交叉熵错误功能。
我正在考虑将一些参数添加到Cross Entropy错误函数中,但我无法找到应该修改的部分。
我看了一下CrossEntropyCriterion.lua,但仍然不知道我在这个文件中没有看到任何等式的方法。
谁能告诉我等式在哪里?或者我应该修改哪个文件?
答案 0 :(得分:0)
要自定义损失功能,您必须更改方法__init
,updateOutput
和updateGradInput
。
__init
是类初始化函数updateOutput
方法时,系统会调用:forward()
当您使用updateGradInput
并且它是您的标准的渐变时,将会调用:backward()
自定义标准的结构如下:
local yourCriterion, parent = torch.class('nn.yourCriterion', 'nn.Criterion')
function yourCriterion:__init(your_parameters):
parent.__init(self)
... (you can add as many parameters as you want to your criterion
and give them the name your prefer)
self.parameters = your_parameters
function yourCriterion:updateOutput(input)
... (your criterion code here)
return value_of_the_criterion
function yourCriterion:updateGradInput(input):
... (your criterion gradient code here)
return gradient
[编辑]:您可以在此处找到交叉熵标准的代码https://github.com/torch/nn/blob/master/CrossEntropyCriterion.lua