WGAN-GP火车损失惨重

时间:2018-11-21 13:58:30

标签: python tensorflow deep-learning

这是WGAN-GP的损失功能

gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
# Gradient penalty
alpha = tf.random_uniform(
    shape=[BATCH_SIZE, 1, 1, 1],
    minval=0.,
    maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0]    # why [0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)

d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)

disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = - d_loss_fake

This is the training loss

发电机损耗在振荡,其值是如此之大。 我的问题是: 发电机损耗正常还是异常?

1 个答案:

答案 0 :(得分:1)

要注意的一件事是您的梯度罚分计算是错误的。以下行:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

实际上应该是:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))

您在第一个轴上进行缩小,但是渐变基于Alpha值所示的图像,因此您必须在轴[1,2,3]上进行缩小。

您的代码中的另一个错误是发电机损耗为:

gen_loss = d_loss_real - d_loss_fake

对于梯度计算,这没有什么区别,因为生成器的参数仅包含在d_loss_fake中。但是,对于发电机损失的价值,这使世界变得与众不同,这就是造成这种损失的原因。

在一天结束时,您应该查看自己关心的实际性能指标,以便确定GAN的质量,例如初始分数或Fréchet初始距离(FID),因为甄别器和生成器的损失仅是轻微的描述性的。