尝试在Pytorch中实现GAN,我得到的结果是,生成器不学习任何东西(或学得不好),而鉴别器pefrorms很好(大约95%正确)。我想反向传播设置有问题。 这个项目很大,所以我没有充分发布它,只是培训中的重要位置:
loss = torch.nn.CrossEntropyLoss()
...
for epoch in range(epochs):
for start_index in range(0,len(x_train), batch_size):
optimizer.zero_grad()
x_batch = x_train[start_index : start_index+batch_size]
y_batch = y_train[start_index : start_index+batch_size]
output = nnet.forward(x_batch)
real_loss_value = loss(output, y_batch)
x_gen, y_gen_false_real = ngen.rnd_batch(x_batch.size(0))
x_gen = x_gen.view(-1,1,28,28)
y_gen_true_fake = y_gen_false_real + 10
gen_output = nnet.forward(x_gen)
gen_optimizer.zero_grad()
gen_output = nnet.forward(x_gen)
gen_success_loss = loss(gen_output, y_gen_false_real)
gen_success_loss.backward()
gen_optimizer.step()
# Measure discriminator's ability to classify real from generated samples
# if fake recognized, the output will be 10-19
gen_output = nnet.forward(x_gen.detach())
fake_loss_value = loss(gen_output, y_gen_true_fake)
d_loss = (real_loss_value + fake_loss_value) / 2
d_loss.backward()
optimizer.step()
optimizer.zero_grad()
这与示例https://github.com/eriklindernoren/PyTorch-GAN中的教程不同 但我想以下应该工作: 鉴别器输出20个标志:第一个0-9代表实数,最后10-19个被识别为伪造者。相应的输出在行
y_gen_true_fake = y_gen_false_real + 10
在损失d_loss = (real_loss_value + fake_loss_value) / 2
的情况下,即使在1个纪元之后,鉴别器也仍会罚款,但是gen_success_loss = loss(gen_output, y_gen_false_real)
的生成器什么也不倾斜,并且只会产生噪声。我猜反向传播调用中出现问题,我不太了解这多个反向传播调用。你能帮我吗?