在Torch中实现自定义丢失功能的必要步骤是什么?
您似乎必须为updateOutput和updateGradInput编写实现。
这就是全部吗?那么你基本上创建了一个新类:
local CustomCriterion, parent = torch.class('CustomCriterion','nn.Criterion')
并实现以下两个功能:
function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)
这是正确的,还是还有更多工作要做?
另外,对于提供的标准,这些函数是用C实现的,但我想Lua实现也可以工作,虽然可能会慢一点?
答案 0 :(得分:0)
我已经实现了表单的函数(伪代码)
--assuming input is partitioned in input_a,input_b
-- target is accordingly partitionend in target_a, target_b
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b)
只是你描述它的方式很多次。所以,据我所知,我认为你的两个问题的答案都是肯定的。
基本上nn/MSECriterion.lua和this似乎支持这一点。