WGAN-GP的正损失越来越大

时间:2018-11-26 10:50:45

标签: python machine-learning deep-learning computer-vision pytorch

我正在研究在PyTorch中使用具有梯度损失的Wasserstein GAN,但始终会产生较大的正发电机损失,并随着时间的推移而增加。 我从Caogang's implementation大量借用,但是使用this implementation中使用的鉴别器和生成器损耗,因为如果尝试使用{{1 }}和Invalid gradient at index 0 - expected shape[] but got [1]参数在草岗实施中使用。

我正在接受增强的WikiArt数据集(> 400k 64x64图像)和CIFAR-10的训练,并且获得了正常的WGAN(具有权重裁剪功能)[即它会在25个纪元后生成可传递的图像],尽管对于所有纪元,D和G损耗都徘徊在3左右[我使用.backward()等来计算它们)。但是,在WGAN-GP版本中,发电机损耗在WikiArt和CIFAR-10数据集上均急剧增加,并且完全无法在WikiArt上产生噪声。

以下是在CIFAR-10上经过25个时代后的损失示例: WGAN-GP loss

我不使用单侧标签平滑之类的技巧,并且使用默认学习率0.001进行训练,使用Adam优化器,并且每次生成器更新时,对鉴别器进行5次训练。为什么会发生这种疯狂的丢失行为,为什么正常的减肥瘦身WGAN在WikiArt上仍然可以“工作”,但是WGANGP完全失败了?

无论结构如何,无论G和D都是DCGAN还是使用this modified DCGAN, the Creative Adversarial Network时,这种情况都会发生,这要求D能够对图像进行分类并且G生成模糊图像。

以下是我当前的one方法的相关部分:

mone

这是(DCGAN)生成器的代码:

torch.mean(D_real)

这是(当前)CAN鉴别器,它具有用于 样式(图像类)分类):

train

根据我的WGAN论文,WGANGP版本和我的GAN的WGAN版本之间唯一的区别是WGAN版本使用self.generator = Can64Generator(self.z_noise, self.channels, self.num_gen_filters).to(self.device) self.discriminator =WCan64Discriminator(self.channels,self.y_dim, self.num_disc_filters).to(self.device) style_criterion = nn.CrossEntropyLoss() self.disc_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.9)) self.gen_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.9)) while i < len(dataloader): j = 0 disc_loss_epoch = [] gen_loss_epoch = [] if self.type == "can": disc_class_loss_epoch = [] gen_class_loss_epoch = [] if self.gradient_penalty == False: # critic training methodology in official WGAN implementation if gen_iterations < 25 or (gen_iterations % 500 == 0): disc_iters = 100 else: disc_iters = self.disc_iterations while j < disc_iters and (i < len(dataloader)): # if using wgan with weight clipping if self.gradient_penalty == False: # Train Discriminator for param in self.discriminator.parameters(): param.data.clamp_(self.lower_clamp,self.upper_clamp) for param in self.discriminator.parameters(): param.requires_grad_(True) j+=1 i+=1 data = data_iterator.next() self.discriminator.zero_grad() real_images, image_labels = data # image labels are the the image's classes (e.g. Impressionism) real_images = real_images.to(self.device) batch_size = real_images.size(0) real_image_labels = torch.LongTensor(batch_size).to(self.device) real_image_labels.copy_(image_labels) labels = torch.full((batch_size,),real_label,device=self.device) if self.type == 'can': predicted_output_real, predicted_styles_real = self.discriminator(real_images.detach()) predicted_styles_real = predicted_styles_real.to(self.device) disc_class_loss = style_criterion(predicted_styles_real,real_image_labels) disc_class_loss.backward(retain_graph=True) else: predicted_output_real = self.discriminator(real_images.detach()) disc_loss_real = -torch.mean(predicted_output_real) # fake noise = torch.randn(batch_size,self.z_noise,1,1,device=self.device) with torch.no_grad(): noise_g = noise.detach() fake_images = self.generator(noise_g) labels.fill_(fake_label) if self.type == 'can': predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images) else: predicted_output_fake = self.discriminator(fake_images) disc_gen_z_1 = predicted_output_fake.mean().item() disc_loss_fake = torch.mean(predicted_output_fake) #via https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/WGAN_GP.py if self.gradient_penalty: # gradient penalty alpha = torch.rand((real_images.size()[0], 1, 1, 1)).to(self.device) x_hat = alpha * real_images.data + (1 - alpha) * fake_images.data x_hat.requires_grad_(True) if self.type == 'can': pred_hat, _ = self.discriminator(x_hat) else: pred_hat = self.discriminator(x_hat) gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(self.device), create_graph=True, retain_graph=True, only_inputs=True)[0] gradient_penalty = lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() disc_loss = disc_loss_fake + disc_loss_real + gradient_penalty else: disc_loss = disc_loss_fake + disc_loss_real if self.type == 'can': disc_loss += disc_class_loss.mean() disc_x = disc_loss.mean().item() disc_loss.backward(retain_graph=True) self.disc_optimizer.step() # train generator for param in self.discriminator.parameters(): param.requires_grad_(False) self.generator.zero_grad() labels.fill_(real_label) if self.type == 'can': predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images) predicted_styles_fake = predicted_styles_fake.to(self.device) else: predicted_output_fake = self.discriminator(fake_images) gen_loss = -torch.mean(predicted_output_fake) disc_gen_z_2 = gen_loss.mean().item() if self.type == 'can': fake_batch_labels = 1.0/self.y_dim * torch.ones_like(predicted_styles_fake) fake_batch_labels = torch.mean(fake_batch_labels,1).long().to(self.device) gen_class_loss = style_criterion(predicted_styles_fake,fake_batch_labels) gen_class_loss.backward(retain_graph=True) gen_loss += gen_class_loss.mean() gen_loss.backward() gen_iterations += 1 class Can64Generator(nn.Module): def __init__(self, z_noise, channels, num_gen_filters): super(Can64Generator,self).__init__() self.ngpu = 1 self.main = nn.Sequential( nn.ConvTranspose2d(z_noise, num_gen_filters * 16, 4, 1, 0, bias=False), nn.BatchNorm2d(num_gen_filters * 16), nn.ReLU(True), nn.ConvTranspose2d(num_gen_filters * 16, num_gen_filters * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(num_gen_filters * 4), nn.ReLU(True), nn.ConvTranspose2d(num_gen_filters * 4, num_gen_filters * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(num_gen_filters * 2), nn.ReLU(True), nn.ConvTranspose2d(num_gen_filters * 2, num_gen_filters, 4, 2, 1, bias=False), nn.BatchNorm2d(num_gen_filters), nn.ReLU(True), nn.ConvTranspose2d(num_gen_filters, 3, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, inp): output = self.main(inp) return output 并削减了鉴别符的权重。

可能是什么原因造成的?我想进行尽可能小的更改,因为我想单独比较损失函数。即使在CIFAR-10上使用未经修改的DCGAN鉴别器时,也会遇到相同的问题。我是否遇到这种情况,可能是因为我目前仅训练25个纪元,还是有其他原因?有趣的是,当我使用LSGAN(class Can64Discriminator(nn.Module): def __init__(self, channels,y_dim, num_disc_filters): super(Can64Discriminator, self).__init__() self.ngpu = 1 self.conv = nn.Sequential( nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False), nn.BatchNorm2d(num_disc_filters), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(num_disc_filters * 2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(num_disc_filters * 4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(num_disc_filters * 8), nn.LeakyReLU(0.2, inplace=True), ) # was this #self.final_conv = nn.Conv2d(num_disc_filters * 8, num_disc_filters * 8, 4, 2, 1, bias=False) self.real_fake_head = nn.Linear(num_disc_filters * 8, 1) # no bn and lrelu needed self.sig = nn.Sigmoid() self.fc = nn.Sequential() self.fc.add_module("linear_layer{0}".format(num_disc_filters*16),nn.Linear(num_disc_filters*8,num_disc_filters*16)) self.fc.add_module("linear_layer{0}".format(num_disc_filters*8),nn.Linear(num_disc_filters*16,num_disc_filters*8)) self.fc.add_module("linear_layer{0}".format(num_disc_filters),nn.Linear(num_disc_filters*8,y_dim)) self.fc.add_module('softmax',nn.Softmax(dim=1)) def forward(self, inp): x = self.conv(inp) x = x.view(x.size(0),-1) real_out = self.sig(self.real_fake_head(x)) real_out = real_out.view(-1,1).squeeze(1) style = self.fc(x) #style = torch.mean(style,1) # CrossEntropyLoss requires input be (N,C) return real_out,style )时,我的GAN完全不会产生噪声。

谢谢!

1 个答案:

答案 0 :(得分:2)

判别器中的批量归一化以梯度罚分破坏Wasserstein GAN。作者自己主张使用层归一化,但是在他们的论文(https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf中显然用粗体写。很难说您的代码中是否还有其他错误,但是我敦促您通读DCGAN和Wasserstein GAN论文,并真正记下超参数。弄错它们确实会破坏GAN的性能,而执行超参数搜索会很快变得昂贵。

通过转置卷积在输出图像中产生楼梯伪像的方式。请改用图像大小调整。有关该现象的深入说明,我可以推荐以下资源(https://distill.pub/2016/deconv-checkerboard/)。