甘斯发电机的损耗函数

时间:2018-07-12 10:12:51

标签: machine-learning computer-vision artificial-intelligence conv-neural-network generative-adversarial-network

我已经深入研究了gans并在pytorch中实现了它,现在,当我浏览该网站Mathematics behing Gans时,我正在研究gans背后的核心统计信息 它说

  

“损失(G)=-损失(D),请注意,我们将发电机成本定义为判别器成本的负数。这是因为我们没有明确的方法来评估发电机成本。”

但是在执行gan时,我们将生成器的损耗定义为:

  

Bintropy生成器和Real标签所产生图像的鉴别器输出之间的交叉熵损失(如原始论文所述)和以下代码(由我实施和测试)

    # train generator
    z_ = to.randn(minibatch,100 ).view(-1, 100, 1, 1)
    z_ = Variable(z_.cuda())
    gen_images = generator(z_)

    D_fake_decisions = discriminator(gen_images).squeeze()
    G_loss = criterion(D_fake_decisions,real_labels)

    discriminator.zero_grad()
    generator.zero_grad()
    G_loss.backward()
    opt_Gen.step()

请向我解释一下两者之间的区别以及正确的区别

代码链接:https://github.com/mabdullahrafique/Gan_with_Pytorch/blob/master/DCGan_mnist.ipynb

谢谢

1 个答案:

答案 0 :(得分:1)

判别器的工作是执行二进制分类以检测真实和伪造之间的距离,因此损失函数为二进制交叉熵。

发生器所做的是从噪声到真实数据的密度估计,并将其馈送到鉴别器以使其蒙蔽。

设计中采用的方法是将其建模为MinMax游戏。现在让我们看一下成本函数:

Source: https://towardsdatascience.com/generative-adversarial-networks-history-and-overview-7effbb713545

有人将其解释为:

  

J(D)中的第一项表示将实际数据提供给鉴别器,鉴别器将要最大化预测一个的对数概率,从而表明数据是真实的。第二项表示由G生成的样本。在这里,鉴别器将要最大化预测零的对数概率,这表明数据是伪造的。另一方面,生成器试图使鉴别器正确的对数概率最小。解决该问题的方法是博弈的平衡点,这是鉴别器损失的鞍点。

由于判别器试图将生成器样本的概率最大化为零,因此生成器的工作将最大化一个。这等效于使Generator的成本函数为负交叉熵,此时J(D)中的第一项将是恒定的。

来源:https://towardsdatascience.com/generative-adversarial-networks-history-and-overview-7effbb713545