pytorch中的GAN:鉴别器获胜,设置错误

时间:2019-07-08 01:24:41

标签: python-3.x generator pytorch discriminator

尝试在Pytorch中实现GAN,我得到的结果是,生成器不学习任何东西(或学得不好),而鉴别器pefrorms很好(大约95%正确)。我想反向传播设置有问题。 这个项目很大,所以我没有充分发布它,只是培训中的重要位置:

loss = torch.nn.CrossEntropyLoss()

...

for epoch in range(epochs):

for start_index in range(0,len(x_train), batch_size):


optimizer.zero_grad() 

    x_batch = x_train[start_index : start_index+batch_size]

    y_batch = y_train[start_index : start_index+batch_size]

    output = nnet.forward(x_batch)

    real_loss_value = loss(output, y_batch)

    x_gen, y_gen_false_real = ngen.rnd_batch(x_batch.size(0))






    x_gen = x_gen.view(-1,1,28,28)  

    y_gen_true_fake = y_gen_false_real + 10

    gen_output = nnet.forward(x_gen)


    gen_optimizer.zero_grad()

    gen_output = nnet.forward(x_gen)

    gen_success_loss =   loss(gen_output, y_gen_false_real)

    gen_success_loss.backward()      


    gen_optimizer.step()        


    # Measure discriminator's ability to classify real from generated samples
    # if fake recognized, the output will be 10-19

    gen_output = nnet.forward(x_gen.detach())

    fake_loss_value = loss(gen_output, y_gen_true_fake)

    d_loss = (real_loss_value + fake_loss_value) / 2

    d_loss.backward()

    optimizer.step()

    optimizer.zero_grad() 

这与示例https://github.com/eriklindernoren/PyTorch-GAN中的教程不同 但我想以下应该工作: 鉴别器输出20个标志:第一个0-9代表实数,最后10-19个被识别为伪造者。相应的输出在行

y_gen_true_fake = y_gen_false_real + 10

在损失d_loss = (real_loss_value + fake_loss_value) / 2的情况下,即使在1个纪元之后,鉴别器也仍会罚款,但是gen_success_loss = loss(gen_output, y_gen_false_real)的生成器什么也不倾斜,并且只会产生噪声。我猜反向传播调用中出现问题,我不太了解这多个反向传播调用。你能帮我吗?

0 个答案:

没有答案