GAN中的损失功能

时间:2018-04-23 19:21:54

标签: python tensorflow machine-learning neural-network generative

我试图构建一个简单的mnist GAN,并且需要更少的说法,它没有用。我经常搜索并修复了大部分代码。虽然我无法真正了解损失函数的工作原理。

这就是我所做的:

loss_d = -tf.reduce_mean(tf.log(discriminator(real_data))) # maximise
loss_g = -tf.reduce_mean(tf.log(discriminator(generator(noise_input), trainable = False))) # maxmize cuz d(g) instead of 1 - d(g)
loss = loss_d + loss_g

train_d = tf.train.AdamOptimizer(learning_rate).minimize(loss_d)
train_g = tf.train.AdamOptimizer(learning_rate).minimize(loss_g)

我得到-0.0作为我的损失值。你能解释一下如何处理GAN中的损失函数吗?

2 个答案:

答案 0 :(得分:2)

似乎您试图将发生器和鉴别器的损失加起来,这是完全错误的! 由于鉴别器同时训练了真实数据和生成的数据,因此您必须创建两个截然不同的损失,一个损失是实数据,另一个损失是传递到鉴别器网络中的噪声数据(生成的)。

尝试如下更改代码:

1)

loss_d_real = -tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(real_data),labels= tf.ones_like(discriminator(real_data))))

2)

loss_d_fake=-tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(noise_input),labels= tf.zeros_like(discriminator(real_data))))

则鉴别器损失将等于= loss_d_real + loss_d_fake。 现在为您的发电机造成损失:

3)

loss_g= tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(genereted_samples), labels=tf.ones_like(genereted_samples)))

答案 1 :(得分:1)

Maryam似乎已经确定了您的虚假损耗值的原因(即,将生成器损耗和鉴别器损耗相加)。只是想补充一点,您可能应该选择判别器的随机梯度下降优化器来代替亚当-这样做为玩minimax游戏提供了网络融合的更强理论保证(来源:https://github.com/soumith/ganhacks)。