我是否需要在自定义丢失函数中定义backward()?

时间:2017-09-25 07:32:51

标签: python-2.7 torch pytorch

我已经定义了自己的损失函数。它确实有效。前馈可能没有问题。但我不确定它是否正确,因为我没有定义向后()。

class _Loss(nn.Module):
    def __init__(self, size_average=True):
        super(_Loss, self).__init__()
        self.size_average = size_average
class MyLoss(_Loss):
    def forward(self, input, target):
        loss = 0
        weight = np.zeros((BATCH_SIZE,BATCH_SIZE))
        for a in range(BATCH_SIZE):
            for b in range(BATCH_SIZE):
                weight[a][b] = get_weight(target.data[a][0])
        for i in range(BATCH_SIZE):
            for j in range(BATCH_SIZE):
                a_ij= (input[i]-input[j]-target[i]+target[j])*weight[i,j]
                loss += F.relu(a_ij)
        return loss

我想问的问题是

1)我是否需要定义reverse()to loss函数?

2)如何定义落后()?

3)在火炬中进行SGD时,有没有办法做数据索引?

1 个答案:

答案 0 :(得分:2)

您可以编写如下所示的损失函数。

def mse_loss(input, target):
            return ((input - target) ** 2).sum() / input.data.nelement() 

您不需要实现向后功能。损失函数的所有上述参数应为PyTorch变量,其余参数由torch.autograd函数处理。