这是WGAN-GP的损失功能
gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
# Gradient penalty
alpha = tf.random_uniform(
shape=[BATCH_SIZE, 1, 1, 1],
minval=0.,
maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0] # why [0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)
disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = - d_loss_fake
发电机损耗在振荡,其值是如此之大。 我的问题是: 发电机损耗正常还是异常?
答案 0 :(得分:1)
要注意的一件事是您的梯度罚分计算是错误的。以下行:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
实际上应该是:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))
您在第一个轴上进行缩小,但是渐变基于Alpha值所示的图像,因此您必须在轴[1,2,3]
上进行缩小。
您的代码中的另一个错误是发电机损耗为:
gen_loss = d_loss_real - d_loss_fake
对于梯度计算,这没有什么区别,因为生成器的参数仅包含在d_loss_fake中。但是,对于发电机损失的价值,这使世界变得与众不同,这就是造成这种损失的原因。
在一天结束时,您应该查看自己关心的实际性能指标,以便确定GAN的质量,例如初始分数或Fréchet初始距离(FID),因为甄别器和生成器的损失仅是轻微的描述性的。