用于域翻译的条件GAN

时间:2018-09-16 10:41:21

标签: python keras generative-adversarial-network

我正在训练GAN从两个不同的图像域(源S和目标T)执行样式转换。由于我有可用的班级信息,因此我拥有一个额外的Q网络(GD除外),该网络可以测量针对目标域及其标签生成的图像的分类结果(LeNet网络) ),并使用D将错误传播到生成器。从系统的收敛中,我注意到D始终从8开始(D网络的损失函数误差),并一直下降到4.5,而G损失函数误差为从1开始迅速下降到0.2。 here可以找到我正在使用的DG的损失函数,而Q网络的损失函数是分类交叉熵。迭代中的误差图为:

enter image description here

D和G的损失函数是:

def discriminator_loss(y_true,y_pred):
      BATCH_SIZE=10
      return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)

def discriminator_on_generator_loss(y_true,y_pred):
     BATCH_SIZE=10
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

def generator_l1_loss(y_true,y_pred):
     BATCH_SIZE=10
     return K.mean(K.abs(K.flatten(y_pred) - K.flatten(y_true)), axis=-1)

D的误差函数总是那么高有意义吗? DG错误的解释是什么? D的损失在开始时应该很小,而在迭代之后应该增加吗?用损失阈值限制D胜过G是个好主意吗?最后,在训练过程中,根据验证集上的损失函数而不是根据我所使用的训练集来计算误差是否有意义? (而不是直接使用train_on_batch使用fit,然后对测试集进行评估)。

编辑:

对于损失,我认为discriminatordiscriminator_on_generator的损失是GAN的正常损失函数,对吗?

1 个答案:

答案 0 :(得分:0)

让G是生成器,D是鉴别器。最初,D和G都未经训练。现在,让我们假设D的学习速度比G快。因此,过了一会儿,G可以区分从真实数据分布中采样的样本和从生成器中采样的样本。最后,G赶上来,学习模拟真实的数据分布。现在,G无法再区分从真实数据分布中采样的样本和从生成器中采样的样本了。

Combined G D losses

因此,我们首先从D和G的高损失开始(区域I)。然后,D的损耗下降的速度快于G的损耗(区域I至II)。随着G的损失继续减少,D的损失增加(区域II)。最后,损失都达到平衡值,从而完成了训练(区域III)。

D loss G loss