这是我第一次与GAN一起工作,而且我面临着一个问题,那就是鉴别器一再地超越生成器。我正在尝试从this article复制PA
模型,并且正在寻找this slightly different implementation来帮助我。
我已经阅读了很多有关GAN的工作方式的论文,并且还遵循了一些教程来更好地理解它们。此外,我已经阅读了有关如何克服主要不稳定因素的文章,但找不到解决这种现象的方法。
在我的环境中,我正在使用PyTorch
和BCELoss()
。在DCGAN PyTorch tutorial之后,我正在使用以下训练循环:
criterion = nn.BCELoss()
train_d = False
# Discriminator true
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
label = torch.full((batch_size,), 1, device=device).cuda()
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label).cuda()
if lossT:
loss_d_real *= 2
if loss_d_real.item() > 0.3:
loss_d_real.backward()
train_d = True
D_x = output_d.mean().item()
# Discriminator false
output_g = generator(image)
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0)
loss_d_fake = criterion(output_d, label).cuda()
D_G_z1 = output_d.mean().item()
if lossT:
loss_d_fake *= 2
loss_d = loss_d_real + loss_d_fake
if loss_d_fake.item() > 0.3:
loss_d_fake.backward()
train_d = True
if train_d:
optim_d.step()
# Generator
label.fill_(1)
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label).cuda()
D_G_z2 = output_d.mean().item()
if lossT:
loss_g *= 2
loss_g.backward()
optim_g.step()
,经过一段时间的解决,一切似乎都很好:
Epoch 1/5 - Step: 1900/9338 Loss G: 3.057388 Loss D: 0.214545 D(x): 0.940985 D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338 Loss G: 2.984724 Loss D: 0.222931 D(x): 0.879338 D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338 Loss G: 2.824713 Loss D: 0.241953 D(x): 0.905837 D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338 Loss G: 2.807455 Loss D: 0.252808 D(x): 0.908131 D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338 Loss G: 2.470529 Loss D: 0.569696 D(x): 0.620966 D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338 Loss G: 2.148863 Loss D: 1.071563 D(x): 0.809529 D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338 Loss G: 2.016863 Loss D: 0.904711 D(x): 0.621433 D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338 Loss G: 2.495639 Loss D: 0.949308 D(x): 0.671085 D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338 Loss G: 2.519842 Loss D: 0.798667 D(x): 0.775738 D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338 Loss G: 2.545630 Loss D: 0.756449 D(x): 0.895455 D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338 Loss G: 2.458109 Loss D: 0.653513 D(x): 0.820105 D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338 Loss G: 2.030103 Loss D: 0.948208 D(x): 0.445385 D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338 Loss G: 1.721604 Loss D: 0.949721 D(x): 0.365646 D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338 Loss G: 1.438854 Loss D: 1.142182 D(x): 0.768163 D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338 Loss G: 1.924418 Loss D: 0.923860 D(x): 0.729981 D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s Epoch ETA: 00:52:11
,即,生成器上的梯度较高,并在一段时间后开始减小,与此同时,鉴别器上的梯度上升。至于损失,发生器下降而鉴别器上升。如果与本教程相比,我想这可以接受。
这是我的第一个问题:我注意到,在本教程中(通常)随着D_G_z1
的增加,D_G_z2
的减少(反之亦然),而在我的示例中发生的次数要少得多。只是巧合还是我做错了什么?
鉴于此,我让培训过程继续进行,但现在我注意到了这一点:
Epoch 3/5 - Step: 1100/9338 Loss G: 4.071329 Loss D: 0.031608 D(x): 0.999969 D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338 Loss G: 3.883331 Loss D: 0.036354 D(x): 0.999993 D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338 Loss G: 3.468963 Loss D: 0.054542 D(x): 0.999972 D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338 Loss G: 3.504971 Loss D: 0.053683 D(x): 0.999972 D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338 Loss G: 3.437765 Loss D: 0.056286 D(x): 0.999941 D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338 Loss G: 3.369209 Loss D: 0.062133 D(x): 0.955688 D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338 Loss G: 3.290109 Loss D: 0.065704 D(x): 0.999975 D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338 Loss G: 3.286248 Loss D: 0.067969 D(x): 0.993238 D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338 Loss G: 3.263996 Loss D: 0.065335 D(x): 0.980270 D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338 Loss G: 3.293503 Loss D: 0.065291 D(x): 0.999873 D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338 Loss G: 3.184164 Loss D: 0.070931 D(x): 0.999971 D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338 Loss G: 3.116310 Loss D: 0.080597 D(x): 0.999850 D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338 Loss G: 3.142180 Loss D: 0.073999 D(x): 0.995546 D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338 Loss G: 3.185711 Loss D: 0.072601 D(x): 0.999992 D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338 Loss G: 3.027437 Loss D: 0.083906 D(x): 0.997390 D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338 Loss G: 3.052374 Loss D: 0.085030 D(x): 0.999924 D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s Epoch ETA: 00:58:12
不仅D(x)
再次增加并且几乎保持不变,而且D_G_z1
和D_G_z2
始终显示相同的值。此外,从损失的角度看,歧视者的表现显然好于产生者。这种行为一直持续到下一个时期,直到训练结束。
因此,我的第二个问题:这正常吗?如果没有,我在程序中做错了什么?如何获得更稳定的培训?
编辑:我尝试按照建议使用MSELoss()
来训练网络,这是输出:
Epoch 1/1 - Step: 100/9338 Loss G: 0.800785 Loss D: 0.404525 D(x): 0.844653 D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338 Loss G: 1.196659 Loss D: 0.014051 D(x): 0.999970 D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338 Loss G: 1.197319 Loss D: 0.000806 D(x): 0.999431 D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338 Loss G: 1.198960 Loss D: 0.000720 D(x): 0.999612 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338 Loss G: 1.212810 Loss D: 0.000021 D(x): 0.999938 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338 Loss G: 1.216168 Loss D: 0.000000 D(x): 0.999945 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338 Loss G: 1.212301 Loss D: 0.000000 D(x): 0.999970 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338 Loss G: 1.214397 Loss D: 0.000005 D(x): 0.999973 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338 Loss G: 1.212016 Loss D: 0.000003 D(x): 0.999932 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338 Loss G: 1.215162 Loss D: 0.000000 D(x): 0.999988 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338 Loss G: 1.216291 Loss D: 0.000000 D(x): 0.999983 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338 Loss G: 1.215526 Loss D: 0.000000 D(x): 0.999978 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s Epoch ETA: 01:10:35
可以看出,情况变得更糟。此外,再次阅读EnhanceNet paper,第4.2.4节(专业训练)指出,所使用的对抗损失函数是BCELoss()
,因为我希望能解决我得到的消失梯度问题MSELoss()
。
答案 0 :(得分:1)
解释GAN损失是一件荒唐的事,因为实际损失值
问题1:(根据我的经验),鉴别器/发电机主导权之间摆动的频率主要基于以下几个因素:学习率和批量大小,这将影响传播的损失。使用的特定损耗度量将影响D&G网络训练方式的差异。 EnhanceNet论文(用于基线)和本教程也使用均方误差损失-您正在使用二进制交叉熵损失,这将改变网络的收敛速度。我不是专家,所以这里的link to Rohan Varma's article that explains the difference between loss functions非常好。奇怪的是,看看您更改丢失功能时网络的行为是否有所不同-试试看并更新问题?
问题2:随着时间的流逝,D损失和G损失都应该稳定在一个值上,但是很难判断他们是否已经在强大的绩效上趋于一致或是否趋于一致。它们之所以收敛,是因为模式崩溃/梯度递减(Jonathan Hui's explanation on problems in training GANs)之类的。我发现最好的方法是实际检查生成的图像的横截面,然后目视检查输出,或者对生成的图像集使用某种感知指标(SSIM,PSNR,PIQ等)。
一些其他有用的线索,可能会有助于您找到答案:
This post在解释GAN损失方面有两个相当不错的指针。
伊恩·古德费洛(Ian Goodfellow)的NIPS2016 tutorial对于如何平衡D&G培训也有一些扎实的想法。