GAN 在训练过程中卡住了

时间:2021-05-17 13:52:25

标签: python deep-learning pytorch generative-adversarial-network

我正在尝试使用 MNIST 数据集构建 GAN。开始训练后,我发现生成器和鉴别器都卡住了(即损失值不再变化,生成器始终输出同样的、类似白噪声的无意义图像)。我尝试过更换鉴别器和生成器的损失函数、调整学习率、增加和删除网络层,但都没有效果。下面附上我的代码:

def tensor_to_plt_im(im: torch.Tensor):
    """Reorder a 3-D image tensor so its first dimension becomes the last.

    The dimensions are permuted as ``(d0, d1, d2) -> (d1, d2, d0)``, which
    turns a channel-first ``C x H x W`` image into the ``H x W x C`` layout
    that is handed to ``plt.imshow`` by the training loop.

    Args:
        im: 3-D tensor to reorder.

    Returns:
        A permuted view of ``im`` (no data is copied).
    """
    channel_last_order = (1, 2, 0)
    return im.permute(*channel_last_order)


def d_loss(discriminator_generated_x, discriminator_true_x):
    """Cross-entropy discriminator loss.

    Computes ``-0.5 * mean(log(D(x))) - 0.5 * mean(log(1 - D(G(z))))``.
    A 1e-8 offset is added inside each logarithm so the result stays finite
    when the argument would otherwise be exactly 0.

    Args:
        discriminator_generated_x: discriminator outputs for generated samples.
        discriminator_true_x: discriminator outputs for real samples.

    Returns:
        Scalar tensor holding the loss.
    """
    eps = 1e-8
    real_log_mean = torch.log(discriminator_true_x + eps).mean()
    fake_log_mean = torch.log(1 - discriminator_generated_x + eps).mean()
    return -0.5 * real_log_mean - 0.5 * fake_log_mean


def gen_loss(discriminator_generated_x, discriminator_true_x):
    """Generator loss defined as the negated discriminator loss.

    Args:
        discriminator_generated_x: discriminator outputs for generated samples.
        discriminator_true_x: discriminator outputs for real samples.

    Returns:
        ``-1 * d_loss(discriminator_generated_x, discriminator_true_x)``.
    """
    discriminator_objective = d_loss(discriminator_generated_x, discriminator_true_x)
    return -1 * discriminator_objective


def gen_loss_non_saturating(discriminator_generated_x):
    """Non-saturating generator loss: ``-0.5 * mean(log(D(G(z))))``.

    A 1e-8 offset is added inside the logarithm so the result stays finite
    when the discriminator outputs exactly 0.

    Args:
        discriminator_generated_x: discriminator outputs for generated samples.

    Returns:
        Scalar tensor holding the loss.
    """
    eps = 1e-8
    fake_log_mean = torch.log(discriminator_generated_x + eps).mean()
    return -0.5 * fake_log_mean


def d_loss_least_squares(discriminator_generated_x, discriminator_true_x):
    """Least-squares discriminator loss.

    Computes ``0.5 * mean((D(x) - 1)^2) + 0.5 * mean(D(G(z))^2)``, with a
    1e-8 offset added to each discriminator output before squaring.

    Args:
        discriminator_generated_x: discriminator outputs for generated samples.
        discriminator_true_x: discriminator outputs for real samples.

    Returns:
        Scalar tensor holding the loss.
    """
    eps = 1e-8
    real_residual = discriminator_true_x + eps - 1
    fake_residual = discriminator_generated_x + eps
    real_term = 0.5 * torch.square(real_residual).mean()
    fake_term = 0.5 * torch.square(fake_residual).mean()
    return real_term + fake_term


def gen_loss_least_squares(discriminator_generated_x):
    """Least-squares generator loss: ``0.5 * mean((D(G(z)) - 1)^2)``.

    A 1e-8 offset is added to the discriminator output before the target of
    1 is subtracted.

    Args:
        discriminator_generated_x: discriminator outputs for generated samples.

    Returns:
        Scalar tensor holding the loss.
    """
    eps = 1e-8
    residual = discriminator_generated_x + eps - 1
    return 0.5 * torch.square(residual).mean()


class MNISTGen(nn.Module):
    """Generator that maps latent vectors to MNIST-sized images.

    A fully connected stack expands a latent vector of size
    ``latent_vec_size`` to ``(num_channels * 4) * 4 * 4`` features; these are
    reshaped to ``(num_channels * 4) x 4 x 4`` and upsampled by three
    transposed convolutions to ``num_input_channels x 28 x 28``. The final
    sigmoid keeps pixel values in (0, 1).

    The module-level settings ``num_channels`` and ``latent_vec_size`` must be
    defined before the class is instantiated.

    Args:
        num_input_channels: number of channels of the generated image.
    """

    def __init__(self, num_input_channels):
        super().__init__()
        # Every hidden layer is followed by a ReLU. Without a nonlinearity the
        # Linear / ConvTranspose2d + BatchNorm layers compose into (nearly) an
        # affine map, which cannot model digit images.
        self.conv = nn.Sequential(
            # input is (num_channels*4) x 4 x 4
            nn.ConvTranspose2d(num_channels * 4, num_channels * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_channels * 2),
            nn.ReLU(True),
            # size. (num_channels*2) x 7 x 7
            nn.ConvTranspose2d(num_channels * 2, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(True),
            # size. (num_channels) x 14 x 14
            nn.ConvTranspose2d(num_channels, num_input_channels, kernel_size=4, stride=2, padding=1, bias=False),
            # No batch norm on the output layer: normalizing the image itself
            # ties every pixel to the batch statistics right before the
            # sigmoid and destabilizes training.
            nn.Sigmoid()  # the final output image size. (num_input_channels) x 28 x 28
        )
        self.linear = nn.Sequential(
            # size. (latent_vec_size)
            nn.Linear(latent_vec_size, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
            # size. (64)
            nn.Linear(64, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            # size. 256
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            # size. 512
            nn.Linear(512, (num_channels * 4) * 4 * 4),
            nn.BatchNorm1d((num_channels * 4) * 4 * 4),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: tensor whose last dimension is ``latent_vec_size``.

        Returns:
            Tensor of generated images with values in (0, 1).
        """
        x = self.linear(x)
        # Reshape the flat features into the 4 x 4 feature map the
        # transposed convolutions start from.
        x = x.view(-1, num_channels * 4, 4, 4)
        return self.conv(x)


class MNISTDisc(nn.Module):
    """Discriminator that scores MNIST-sized images as real or generated.

    Three strided convolutions reduce a ``num_input_channels x 28 x 28``
    image to ``(num_channels * 4) x 4 x 4`` features; a fully connected stack
    then maps the flattened features to a single sigmoid output per sample.

    The module-level setting ``num_channels`` must be defined before the
    class is instantiated.

    Args:
        num_input_channels: number of channels of the input image.
    """

    def __init__(self, num_input_channels):
        super().__init__()
        self.conv = nn.Sequential(
            # input is (num_input_channels) x 28 x 28
            nn.Conv2d(num_input_channels, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # size. (num_channels) x 14 x 14
            nn.Conv2d(num_channels, num_channels * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_channels * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # size. (num_channels*2) x 7 x 7
            nn.Conv2d(num_channels * 2, num_channels * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_channels * 4),
            nn.LeakyReLU(0.2, inplace=True),  # size. (num_channels*4) x 4 x 4
        )
        # Each hidden linear layer gets its own LeakyReLU. Without one,
        # consecutive Linear + BatchNorm layers collapse into (nearly) a
        # single linear map and add no modelling capacity.
        self.linear = nn.Sequential(
            # size. (num_channels*4) x 4 x 4, flattened
            nn.Linear((num_channels * 4) * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # size. (512)
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # size. 256
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # size. 64
            nn.Linear(64, 1),
            nn.Sigmoid()  # final classification
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: batch of images accepted by the first convolution.

        Returns:
            Tensor with one sigmoid score in (0, 1) per sample.
        """
        x = self.conv(x)
        x = x.view(-1, (num_channels * 4) * 4 * 4)  # reshape input so that it fits fc layer
        return self.linear(x)


def train(g_net: nn.Module, d_net: nn.Module):
    """Run the adversarial training loop for a generator/discriminator pair.

    The discriminator is updated on every batch. The generator is updated
    only on batches where ``i % discriminator_iterations == 0``. Every 500
    iterations, and on the very last batch, a grid of samples generated from
    ``fixed_noise`` is stored and displayed.

    Reads the module-level settings ``lr``, ``beta1``, ``num_epochs``,
    ``dataloader``, ``device``, ``latent_vec_size``,
    ``discriminator_iterations`` and ``fixed_noise``.

    Args:
        g_net: generator network, called with a batch of latent vectors.
        d_net: discriminator network, called with a batch of images.

    Returns:
        Tuple ``(G_losses, D_losses, gen_lst)``: generator losses (one per
        generator update), discriminator losses (one per batch) and the
        stored sample grids.
    """
    G_losses, D_losses, iters, gen_lst = [], [], 0, []
    optimizerD = optim.Adam(d_net.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(g_net.parameters(), lr=lr, betas=(beta1, 0.999))
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            d_net.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # Forward pass real batch through D
            real_output = d_net(real_cpu).view(-1)
            D_x = real_output.mean().item()

            # Generate batch of latent vectors
            noise = torch.randn(b_size, latent_vec_size, device=device)
            # Generate fake image batch with G
            fake = g_net(noise)
            # Classify the fake batch with D; detach so that this backward
            # pass does not propagate into the generator
            fake_output = d_net(fake.detach()).view(-1)
            # Single loss over the real and the fake batch
            loss = d_loss(fake_output, real_output)
            loss.backward()
            D_G_z1 = fake_output.mean().item()
            errD = loss
            # Update D
            optimizerD.step()

            # perform a generator iteration every 'discriminator_iterations' steps of the discriminator
            if i % discriminator_iterations == 0:
                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                g_net.zero_grad()
                # Since we just updated D, perform another forward pass of all-fake batch through D
                fake_output = d_net(fake).view(-1)
                # Calculate G's loss based on this output
                errG = gen_loss_non_saturating(fake_output)
                # Calculate gradients for G
                errG.backward()
                D_G_z2 = fake_output.mean().item()
                # Update G
                optimizerG.step()

                # Output training stats
                if i % 50 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, num_epochs, i, len(dataloader),
                             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                G_losses.append(errG.item())

            D_losses.append(errD.item())  # Save Losses for plotting later

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = g_net(fixed_noise).detach().cpu()
                gen_lst.append(vutils.make_grid(fake))
                plt.imshow(tensor_to_plt_im(gen_lst[-1]))
                plt.show()
            iters += 1

    # Hand the collected histories back to the caller instead of discarding
    # them when the function returns.
    return G_losses, D_losses, gen_lst


if __name__ == '__main__':
    # `dataloader` and `fixed_noise` are module-level names on purpose:
    # train() reads them as globals.
    dataloader = generate_mnist_data_set()
    # Move both networks to `device`. train() places the real batches and the
    # noise there, so networks left on the CPU fail on any other device.
    generator = MNISTGen(num_input_channels=1).to(device)
    discriminator = MNISTDisc(num_input_channels=1).to(device)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    # NOTE(review): `image_size` is used here as the number of samples to
    # visualize - confirm this is intended.
    fixed_noise = torch.randn(image_size, latent_vec_size, device=device)
    train(generator, discriminator)

我将不胜感激任何建议、评论和一般指导,因为就像我说的那样,我完全被困住了。

提前致谢

1 个答案:

答案 0 :(得分:0)

> 生成器和鉴别器都卡住了(即损失值不会改变)

这正是 GAN 训练过程的常见表现。请注意,这两个神经网络是相互对抗、相互“学习”的,因此损失值本身并不能直接反映训练进展。建议先不要过分关注损失值,让训练继续运行,过一段时间后再检查生成的结果。

enter image description here

请注意,根据问题规模和硬件条件,训练可能需要数小时甚至数天。