如何使用生成对抗网络中的鉴别器输出训练生成器

时间:2017-06-23 19:45:46

标签: machine-learning neural-network deep-learning generative

最近我了解了Generative Adversarial Networks

为了训练发电机,我对它的学习方式感到困惑。 Here是GAN的实施:

`# train generator
            z = Variable(xp.random.uniform(-1, 1, (batchsize, nz), dtype=np.float32))
            x = gen(z)
            yl = dis(x)
            L_gen = F.softmax_cross_entropy(yl, Variable(xp.zeros(batchsize, dtype=np.int32)))
            L_dis = F.softmax_cross_entropy(yl, Variable(xp.ones(batchsize, dtype=np.int32)))

        # train discriminator

        x2 = Variable(cuda.to_gpu(x2))
        yl2 = dis(x2)
        L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))

        #print "forward done"

        o_gen.zero_grads()
        L_gen.backward()
        o_gen.update()

        o_dis.zero_grads()
        L_dis.backward()
        o_dis.update()`

所以它计算了发电机的损耗,正如文中提到的那样。 但是,它根据Discriminator输出调用Generator向后功能。鉴别器输出只是一个数字(不是数组)。

但是我们知道,一般来说,为了训练网络,我们在最后一层计算一个损失函数(最后一层输出和实际输出之间的损失)然后我们计算梯度。因此,例如,如果输出为64 * 64,那么我们将它与64 * 64图像进行比较,然后计算损耗并进行反向传播。

然而,在我在Generative Adversarial Networks中看到的代码中,我看到它们从鉴别器输出(这只是一个数字)计算发生器的损耗,然后它们调用Generator的反向传播。发生器最后一层是例如64 * 64像素,但鉴别器丢失是1 * 1(这与通常的网络不同)所以我不明白它是如何导致生成器被学习和训练的?

我想如果我们连接两个网络(连接Generator和Discriminator)然后调用反向传播但只更新Generators参数,它就有意义了,它应该可行。但我在代码中看到的完全不同。

所以我问这是怎么回事?

谢谢

1 个答案:

答案 0 :(得分:0)

您说“但是,它会基于Discriminator输出调用Generator的后退功能。鉴别器输出只是一个数字(不是数组),而损失始终是一个标量值。当我们计算两个图像的均方误差时,它也是一个标量值。

L_adversarial = E [log(D(x())] + E [log(1-D(G(z))]

x来自真实数据分发

z是由Generator转换的潜在数据分布

回到您的实际问题,鉴别器网络在最后一层具有S型激活功能,这意味着它的输出范围为[0,1]。鉴别器试图通过最大化损失函数中添加的两个项来最大化此损失。第一项的最大值为0,在D(x)为1时出现,第二项的最大值也为0,在1-D(G(z))为1时出现,这意味着D(G(z))为0。因此,鉴别器尝试对我的损失函数进行最大化,从而对它进行二进制分类,从而在输入x(真实数据)时尝试输出1,而在输入G(z)(生成的伪数据)时尝试输出0。 但是,生成器试图通过将生成的伪样本与真实样本相似来使这种损失最小化,换句话说,它试图欺骗鉴别器。随着时间的流逝,生成器和鉴别器会越来越好。这就是GAN的直觉。

代码在pytorch中

bce_loss = nn.BCELoss() #bce_loss = -ylog(y_hat)-(1-y)log(1-y_hat)[similar to L_adversarial]

Discriminator = ..... #some network   
Generator = ..... #some network

optimizer_generator = ....... #some optimizer for generator network    
optimizer_discriminator = ....... #some optimizer for discriminator network       

z = ...... #some latent data distribution that is transformed by the generator
real = ..... #real data distribution

#####################
#Update Discriminator
#####################
fake = Generator(z)
fake_prediction = Discriminator(fake)
real_prediction = Discriminator(real)
discriminator_loss = bce_loss(fake_prediction,torch.zeros(batch_size))+bce_loss(real_prediction,torch.ones(batch_size))
discriminator_loss.backward()
optimizer_discriminator.step()

#################
#Update Generator
#################
fake = Generator(z)
fake_prediction = Discriminator(fake)
generator_loss = bce_loss(fake_prediction,torch.ones(batch_size))
generator_loss.backward()
optimizer_generator.step()