如何在Keras中编译模型后动态冻结权重?

时间:2017-07-17 21:47:41

标签: python tensorflow neural-network keras theano

我想在Keras训练一个GAN。我的最终目标是BEGAN,但我从最简单的目标开始。理解如何正确冻结权重是必要的,这就是我正在努力解决的问题。

在发电机训练期间,可能不会更新鉴别器权重。我想交替地冻结解冻鉴别器来交替训练生成器和鉴别器。问题是在 discriminator 模型上将 trainable 参数设置为false,甚至在其'权重上'不会停止模型训练(以及更新权重)。另一方面,当我将 trainable 设置为False后编译模型时,权重变为 unfreezable 。我无法在每次迭代后编译模型,因为这会否定整个训练的想法。

由于这个问题,许多Keras实现似乎被窃听或者它们起作用,因为在旧版本或其他东西中有一些非直观的技巧。

3 个答案:

答案 0 :(得分:8)

我几个月前尝试过这个示例代码并且它有效: https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py

它不是最简单的GAN形式,但就我记忆而言,删除分类丢失并将模型转换为GAN并不太难。

您不需要打开/关闭鉴别器的可训练属性并重新编译。只需创建和编译两个模型对象,一个包含trainable=True(代码中为discriminator),另一个包含trainable=False(代码中为combined)。

当您更新鉴别器时,请致电discriminator.train_on_batch()。当您更新生成器时,请致电combined.train_on_batch()

答案 1 :(得分:0)

您可以使用tf.stop_gradient有条件地冻结权重吗?

答案 2 :(得分:0)

也许你的对抗网(生成器加鉴别器)是在' Model'中写的。 但是,即使你设置 d.trainable = False ,独立的d net也设置为不可训练,但整个对抗网中的d仍然是可训练的。

你可以在设置 d.trainable = False 之后使用d_on_g.summary(),你会知道我的意思(注意可训练的变量)。