保存并加载GAN模型以使用Keras进行继续训练

时间:2020-01-22 11:06:18

标签: tensorflow keras neural-network generative-adversarial-network

我正在尝试保存GAN模型,以便以后可以继续训练。

基本上,在训练循环之后,我将使用以下命令分别保存鉴别器和生成器:

discriminator.save("discriminatorTrained.h5")
generator.save("generatorTrained.h5")

然后,当我想继续训练时,我会像这样加载它们:

# Load Discriminator and Generator
discriminator = load_model('discriminatorTrained.h5')
generator = load_model('generatorTrained.h5')
discriminator.trainable = False

然后我用这样的加载的鉴别器和生成器制作一个新的GAN:

#Make new GAN from trained discriminator and generator
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)
gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)

然后从一开始就运行相同的训练脚本。

我没有收到任何错误消息,但是它似乎可以正常工作,但是,如果比较结果(例如保存和加载并继续训练10次),似乎生成器的效果不如我进行一次单次训练,次数最多为10个。

所以我怀疑我可能在这里丢失了一些东西,在此过程中是否丢失了一些培训信息,也许是在重新创建GAN模型时?

0 个答案:

没有答案