如何使用ModelCheckpoint回调加载我的Tensorflow模型?

时间:2020-02-12 07:05:28

标签: deep-learning callback conv-neural-network tensorflow2.0 tf.keras

我已经训练了一个模型,并使用ModelCheckpoint保存了权重:

checkpoint_callback = ModelCheckpoint(
    filepath = checkpoint_prefix,
    save_weights_only = True,
    save_freq = 'epoch')

在我的模型训练期间的夜间,电源关闭了一段时间,计算机也关闭了。现在,我打开了Jupyter笔记本,我想从一开始就不进行培训就加载模型。我应该怎么做而不重新编译而只使用检查点呢? 我也有张量板回调:

tensorboard_callback = TensorBoard(
    log_dir = 'tensorboard_logs\\'+ model_name,
    histogram_freq = 5,
    write_graph = True,
    update_freq = 'epoch')

1 个答案:

答案 0 :(得分:1)

由于只保存了模型的权重,因此需要重建图形,然后在其上加载最后一个检查点权重。

因此,您必须重新创建模型并进行编译。
对于下一次,如果您要保存完整的模型,而不必每次加载时都再次编译,请将save_weights_only设置为False
它允许您使用keras.models.load_model()加载模型并在之后直接拟合。

model = Sequential()
model.add() 
...
model.compile()

然后加载您的体重:

model.load_weights(checkpoint_prefix)

然后可以正常使用它:

model.fit( ... )