在Pytorch中训练GAN

时间:2019-02-19 21:38:03

标签: machine-learning pytorch

我试图用两个不同的损失函数分别训练生成器和鉴别器。这是计算损耗和梯度的代码。

    A_pre_B = netG_A2B(noisy_A)
    pixel_loss_A2B_A = criterion_identity(synth_C,A_pre_B)  * 0.01

    pred_fake = netD_B(A_pre_B) 
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real) 

    loss_G = pixel_loss_A2B_A + loss_GAN_A2B 
    loss_G.backward()
    optimizer_G.step()

    optimizer_D_B.zero_grad()
    fake_B = A_pre_B.detach()
    pred_fake = netD_B(fake_B)
    loss_D_fake = criterion_GAN(pred_fake, target_fake)

    loss_D_B = (loss_D_real + loss_D_fake)*0.5
    loss_D_B.backward()
    optimizer_D_B.step()

我读到here,在计算反向传播的梯度时,应保持发生器和鉴别器的权重保持不变。培训融合得不好。这个 answer说,在更新步骤中,我的生成器优化器不会更新鉴别器权重。但是当前我调用A_pre_B.detach()并从生成器中重新获取pred_fake。为什么我不能在这里调用pred_fake.detach()?我这样做时会收到内存泄漏错误。

0 个答案:

没有答案
相关问题