PyTorch中的生成对抗网络(GAN)培训生成器

时间:2020-06-06 15:45:27

标签: machine-learning deep-learning pytorch generative-adversarial-network

我正在努力在PyTorch 1.5.0中实现生殖对抗网络(GAN)。

为了计算生成器的损失,我同时计算了鉴别器将全实型迷你批次和全(生成器生成的)假迷你批次误分类的负概率。然后,我依次对这两个部分进行反向传播,最后应用阶跃函数。

计算和反向传播部分损失,该损失是所生成假数据的错误分类的函数,因为在该损失项的反向传播期间,反向路径会通过生成器首先产生了虚假数据。

但是,对所有实数据微型批次的分类不涉及通过生成器传递数据。因此,我想知道以下代码是否仍会为生成器计算梯度,或者根本不会计算任何梯度(因为向后路径不会穿过生成器,并且在更新生成器时鉴别器处于eval模式)?

# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()

# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long()  # Pretend true targets were fake
y_pred = net.discriminator(x_real)  # Produces softmax probability distribution over (0=label_fake,1=label_real)

loss_real = NLLLoss(torch.log(y_pred), y_true) 
loss_real.backward()
optimizer_generator.step()

如果这无法正常工作,我该如何使它工作?预先感谢!

1 个答案:

答案 0 :(得分:1)

由于没有使用生成器的任何参数进行计算,因此不会将梯度传播到生成器。鉴别器处于评估模式不会阻止梯度传播到生成器,尽管如果您使用的层与评估模式相比,在评估模式下的行为与火车模式不同,例如降落,则它们会略有不同。

对真实图像进行错误分类不是训练生成器的一部分,因为它不会从此信息中获得任何收益。从概念上讲,生成器应该从鉴别器无法正确分类真实图像这一事实中学到什么?生成器的唯一任务是创建一个伪造的图像,使辨别者认为它是真实的,因此,生成器唯一相关的信息是辨别器是否能够识别伪造的图像。如果鉴别者确实能够识别伪造的图像,则生成器需要进行自我调整以创建更具说服力的伪造。

当然,这不是二进制情​​况,但是生成器总是尝试改进伪造的图像,以使鉴别器更加确信它是真实的图像。生成器的目的不是要使鉴别器成为可疑的(概率为0.5,是真实的还是假的),而是使鉴别器完全相信它是真实的,即使它是假的。这就是为什么他们是对抗性的,而不是合作性的。