如何在Torch中自定义丢失功能?

时间:2017-03-06 06:03:26

标签: torch

我目前正在使用Torch 7,我需要自定义丢失功能,尤其是交叉熵错误功能。

我正在考虑将一些参数添加到Cross Entropy错误函数中,但我无法找到应该修改的部分。

我看了一下CrossEntropyCriterion.lua,但仍然不知道我在这个文件中没有看到任何等式的方法。

谁能告诉我等式在哪里?或者我应该修改哪个文件?

1 个答案:

答案 0 :(得分:0)

要自定义损失功能,您必须更改方法__initupdateOutputupdateGradInput

  • __init是类初始化函数
  • 在您的标准中使用updateOutput方法时,系统会调用
  • :forward() 当您使用updateGradInput并且它是您的标准的渐变时,将会调用
  • :backward()

自定义标准的结构如下:

local yourCriterion, parent = torch.class('nn.yourCriterion', 'nn.Criterion')

function yourCriterion:__init(your_parameters):
  parent.__init(self)
  ... (you can add as many parameters as you want to your criterion
       and give them the name your prefer)
  self.parameters = your_parameters


function yourCriterion:updateOutput(input)
  ... (your criterion code here)
  return value_of_the_criterion

function yourCriterion:updateGradInput(input):
  ... (your criterion gradient code here)
  return gradient

[编辑]:您可以在此处找到交叉熵标准的代码https://github.com/torch/nn/blob/master/CrossEntropyCriterion.lua