我是否必须每批重新编译Gan,以防止歧视者学习?

时间:2019-01-05 14:25:58

标签: keras generative-adversarial-network

我有一个像这样的甘人

generator = Model(g_in, g_out)
generator.compile(...)

discriminator = Model(d_in, d_out)
discriminator.trainable = True
discriminator.compile(..)

discriminator.trainable = False

gan = Model(inputs=.., outputs=..)
gan.compile(..)

#iterate over epochs and batches, without compiling

它学习并给出可接受的输出。但是我得到警告:

“ keras \ engine \ training.py:490:用户警告:可训练砝码与收集的可训练砝码之间存在差异,您是否设置了model.trainable却没有在之后调用model.compile?   “可训练重量与收集的可训练重量之间的差异””

如果我重新编译鉴别器并每批gan,警告就会消失,但是一次迭代会花费更长的时间,并且训练速度会变慢。

for epoch:
  for batch:

    fakes=generator.predict_on_batch(batch)

    discriminator.trainable = True
    discriminator.compile(..)

    discriminator.train_on_batch(batch, ..)
    discriminator.train_on_batch(fakes, ..)

    discriminator.trainable = False
    discriminator.compile(..)
    gan.compile(..)

    gan.train_on_batch(batch,..)

其中哪一个是正确的?

1 个答案:

答案 0 :(得分:0)

这是预料之中的,不需要重新编译每个批处理。 Keras对此有一个开放的错误:https://github.com/keras-team/keras/issues/8585

那里的答复有一些如何通过警告的示例,我在这里不再重复。如果您不确定模型的具体细节,还有一封回复可以为您提供很好的建议,以帮助您验证自己是否正在训练自己的训练内容:https://github.com/keras-team/keras/issues/8585#issuecomment-385729276