将我的自定义损失功能添加到火炬

时间:2015-11-11 10:38:42

标签: torch

我想为火炬添加一个损失函数来计算预测值和目标值之间的编辑距离。 有没有一种简单的方法来实现这个想法? 或者我是否必须使用向后和向前功能编写自己的类?

2 个答案:

答案 0 :(得分:6)

如果您的标准可以表示为现有模块和标准的组合,那么使用容器简单地构造这样的组合是个好主意。唯一的问题是标准容器只能用于模块而不是标准。不同之处在于:forward方法签名:

module:forward(input)
criterion:forward(input, target)

幸运的是,我们可以自由定义我们自己的容器,它也可以使用标准。例如,顺序:

local GeneralizedSequential, _ = torch.class('nn.GeneralizedSequential', 'nn.Sequential')

function GeneralizedSequential:forward(input, target)
    return self:updateOutput(input, target)
end

function GeneralizedSequential:updateOutput(input, target)
    local currentOutput = input
    for i=1,#self.modules do
        currentOutput = self.modules[i]:updateOutput(currentOutput, target)
    end
    self.output = currentOutput
    return currentOutput
end

下面是如何实现具有此通用顺序容器的nn.CrossEntropyCriterion的说明:

function MyCrossEntropyCriterion(weights)
    criterion = nn.GeneralizedSequential()
    criterion:add(nn.LogSoftMax())
    criterion:add(nn.ClassNLLCriterion(weights))
    return criterion
end

检查一切是否正确:

output = torch.rand(3,3)
target = torch.Tensor({1, 2, 3})

mycrit = MyCrossEntropyCriterion()
-- print(mycrit)
print(mycrit:forward(output, target))
print(mycrit:backward(output, target))

crit = nn.CrossEntropyCriterion()
-- print(crit)
print(crit:forward(output, target))
print(crit:backward(output, target))

答案 1 :(得分:0)

只是要添加到接受的答案中,您必须注意您定义的损失函数(在您的情况下编辑距离)相对于网络参数是可区分的。