如何在Pytorch中实现上限JSD损失?

时间:2017-12-13 06:35:04

标签: python deep-learning pytorch

我正在尝试使用pytorch“复制”TextGAN而我是pytorch的新手。我目前关注的是复制L_G(方程式7第3页),这是我目前的代码:

class JSDLoss(nn.Module):

    def __init__(self):
        super(JSDLoss,self).__init__()

    def forward(self, batch_size, f_real, f_synt):
        assert f_real.size()[1] == f_synt.size()[1]

        f_num_features = f_real.size()[1]
        identity = autograd.Variable(torch.eye(f_num_features)*0.1, requires_grad=False)

        if use_cuda:
            identity = identity.cuda(gpu)

        f_real_mean = torch.mean(f_real, 0, keepdim=True)
        f_synt_mean = torch.mean(f_synt, 0, keepdim=True)

        dev_f_real = f_real - f_real_mean.expand(batch_size,f_num_features)
        dev_f_synt = f_synt - f_synt_mean.expand(batch_size,f_num_features)

        f_real_xx = torch.mm(torch.t(dev_f_real), dev_f_real)
        f_synt_xx = torch.mm(torch.t(dev_f_synt), dev_f_synt)

        cov_mat_f_real = (f_real_xx / batch_size) - torch.mm(f_real_mean, torch.t(f_real_mean)) + identity
        cov_mat_f_synt = (f_synt_xx / batch_size) - torch.mm(f_synt_mean, torch.t(f_synt_mean)) + identity

        cov_mat_f_real_inv = torch.inverse(cov_mat_f_real)
        cov_mat_f_synt_inv = torch.inverse(cov_mat_f_synt)

        temp1 = torch.trace(torch.add(torch.mm(cov_mat_f_synt_inv, cov_mat_f_real), torch.mm(cov_mat_f_real_inv, cov_mat_f_synt)))
        temp1 = temp1.view(1,1)
        temp2 = torch.mm(torch.mm((f_synt_mean - f_real_mean), (cov_mat_f_synt_inv + cov_mat_f_real_inv)), torch.t(f_synt_mean - f_real_mean))
        loss_g = torch.add(temp1, temp2).mean()

        return loss_g

有效。但是,我怀疑这不是创建自定义损失的方法。任何形式的帮助都非常感谢!在此先感谢:)

1 个答案:

答案 0 :(得分:0)

如何在Pytorch中创建自定义丢失

这是你在Pytorch中创建自定义丢失的方法。您需要满足以下要求:

  
      
  • 最终由损失函数返回的值必须是标量值。不是矢量/张量。
  •   
  • 返回的值必须是变量。这样它可以用于更新模型中的参数。这样做的最佳方式   是确保传入的x和y都是变量。   这样,两者的任何功能也都是变量。
  •   
  • 定义__init__forward方法
  •   

您可以在Pytorch源代码中找到几个可用作示例的损耗模块:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py

如果您将迷你批量张量传递给损失函数,则无需将小批量大小传递给forward函数,因为可以在{{1}中计算大小功能。

如何使用自定义丢失

一旦实现了损失功能,您可以按如下方式使用它,例如:

forward

loss = YourLoss() input = autograd.Variable(torch.randn(3, 5), requires_grad=True) target = autograd.Variable(torch.randn(3, 5)) output = loss(input, target) output.backward() 为您网络中loss.backward()的每个参数dloss/dx计算x。对于每个参数requires_grad=True,这些累积到x.grad。在伪代码中:

x

optimizer.step使用渐变x.grad += dloss/dx 更新x的值。例如,SGD优化器执行:

x.grad

x += -lr * x.grad 为优化程序中的每个参数optimizer.zero_grad()清除x.grad。在x之前调用它是很重要的,否则你将累积多次传递的渐变。