冻结GAN中的鉴别器层

时间:2020-02-02 19:40:22

标签: tensorflow keras neural-network generative-adversarial-network

我正在尝试在Keras中实现SimGAN,但我认为问题通常与GAN和冻结层有关。

据我了解,我需要三种模型:

  1. 优化程序,它处理合成图像以使其更加逼真。 (在其他GAN架构中,这可能是 generator ,它会接受随机噪声。)

  2. 鉴别器,用于处理图像并将其分类为合成图像或真实图像。

  3. 一个组合模型,该模型将合成图像通过精炼机,然后将精炼图像传递到鉴别器。

在组合模型中,我们希望冻结鉴别器层 ,以便在训练组合模型时,我们仅更新精简器层。

另外,我们训练仅鉴别器模型,显然不应冻结层。鉴别器模型和组合模型的各层应共享权重,以便它们都可以更新。

这是我到目前为止所拥有的:

refiner_model = make_refiner_model(input_shape=(img_height, img_width, img_channels))
discriminator_model = make_discriminator_model(input_shape=refiner_model.output_shape[1:])

# create combined model with frozen discriminator layers
synthetic_image_tensor = layers.Input(refiner_model.input_shape[1:])
refiner_model_output = refiner_model(synthetic_image_tensor)
combined_output = discriminator_model(refiner_model_output)

combined_model = models.Model(
    inputs=synthetic_image_tensor,
    outputs=[refiner_model_output, combined_output],
    name='combined'
)

如何在组合模型中冻结标识符层,而在仅标识符的模型中冻结


Keras FAQ明确建议以下内容:

refiner_model.compile(...)
discriminator_model.compile(...)

discriminator_model.trainable = False
combined_model.compile(...)

但是当我打印出discriminator_model.summary()时,参数数量增加了一倍吗?

Total params: 151,812
Trainable params: 75,906
Non-trainable params: 75,906

然后我将得到warnings about changing .trainable without recompiling,最终以this error失败。

0 个答案:

没有答案