我正在研究GANS,所以我需要在下班后保存模型。然后,我必须重新训练先前保存的模型,并将其保存在原来的位置。我保存了这三个模型,以便以后继续训练。
Discriminator Model.h5
Generator Model.h5
Generator-on-Discriminator Model.h5
对于这些模型,我正在使用perceptual loss
和Wasserstein loss
。但是,当我load_model
重新训练保存的模型时,遇到以下错误。
Unknown loss function:wasserstein_loss
我也尝试过Discriminator.compile(loss=Wasserstein loss)
,但这仍然不能解决我的问题。谁能指导我一下,并告诉我使用train_on_batch()重新训练保存的模型的可能性。
答案 0 :(得分:3)
由我自己解决
在加载模型时定义custom_objects={'wassertein_loss':wassertein_loss}
以及路径解决了我的问题。即
Discriminator=load_model(model_path, custom_objects={'wassertein_loss':wassertein_loss} )