在Keras中重新训练使用train_on_batch()进行训练的已保存模型

时间:2019-01-10 06:13:42

标签: python-3.x tensorflow keras deep-learning generative-adversarial-network

我正在研究GANS,所以我需要在下班后保存模型。然后,我必须重新训练先前保存的模型,并将其保存在原来的位置。我保存了这三个模型,以便以后继续训练。

Discriminator Model.h5
Generator Model.h5
Generator-on-Discriminator Model.h5

对于这些模型,我正在使用perceptual lossWasserstein loss。但是,当我load_model重新训练保存的模型时,遇到以下错误。

Unknown loss function:wasserstein_loss

我也尝试过Discriminator.compile(loss=Wasserstein loss),但这仍然不能解决我的问题。谁能指导我一下,并告诉我使用train_on_batch()重新训练保存的模型的可能性。

1 个答案:

答案 0 :(得分:3)

由我自己解决

在加载模型时定义custom_objects={'wassertein_loss':wassertein_loss}以及路径解决了我的问题。即

Discriminator=load_model(model_path, custom_objects={'wassertein_loss':wassertein_loss} )
相关问题