在火炬

时间:2016-05-28 23:48:52

标签: lua machine-learning torch loss

在Torch中实现自定义丢失功能的必要步骤是什么?

您似乎必须为updateOutput和updateGradInput编写实现。

这就是全部吗?那么你基本上创建了一个新类:

local CustomCriterion, parent =   torch.class('CustomCriterion','nn.Criterion')

并实现以下两个功能:

function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)

这是正确的,还是还有更多工作要做?

另外,对于提供的标准,这些函数是用C实现的,但我想Lua实现也可以工作,虽然可能会慢一点?

1 个答案:

答案 0 :(得分:0)

我已经实现了表单的函数(伪代码)

--assuming input is partitioned in input_a,input_b
--         target is accordingly partitionend in target_a, target_b  
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b)

只是你描述它的方式很多次。所以,据我所知,我认为你的两个问题的答案都是肯定的。

基本上nn/MSECriterion.luathis似乎支持这一点。