tf.keras如何保存ModelCheckPoint对象

时间:2019-12-09 19:12:23

标签: keras callback pickle google-colaboratory tf.keras

ModelCheckpoint可用于基于特定的受监视指标保存最佳模型。因此,它显然具有有关存储在其对象中的最佳度量的信息。例如,如果您在google colab上进行培训,则您的实例可能会在没有警告的情况下被杀死,并且经过长时间的培训,您可能会丢失此信息。

我试图腌制ModelCheckpoint对象,但得到:

TypeError: can't pickle _thread.lock objects  

这样一来,当我拿回笔记本时,我可以重复使用同一对象。有什么好方法吗?您可以尝试通过以下方式进行复制:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

with open('chkpt_cb.pickle', 'w') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)

2 个答案:

答案 0 :(得分:0)

我认为您可能误解了ModelCheckpoint对象的预期用法。 callback是在特定阶段的训练期间定期被调用的。特别是,在每个时期之后(如果您保留默认的period=1),都会调用ModelCheckpoint回调并将模型保存到磁盘中,该文件名是您为filepath指定的文件名。以与here相同的方式保存模型。然后,如果您想稍后加载该模型,则可以执行类似的操作

from keras.models import load_model
model = load_model('my_model.h5')

关于SO的其他答案为从保存的模型继续进行训练提供了很好的指导和示例,例如:Loading a trained Keras model and continue training。重要的是,保存的H5文件存储了继续训练所需的有关模型的所有信息。

按照Keras documentation中的建议,您不应使用pickle序列化模型。只需使用您的“ fit”功能注册ModelCheckpoint回调:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          callbacks=[chkpt_cb])

您的模型将保存在一个具有您自己名字的H5文件中,其中会自动为您格式化历元号和损失值。例如,您为第5个时期保存的文件,其损失为0.0023,看起来像model.05-.0023.h5,并且由于您设置了save_best_only=True,因此只有在您的损失好于先前保存的文件的情况下,才会保存模型不要用一堆不需要的模型文件污染目录。

答案 1 :(得分:0)

如果不要腌制回调对象(由于线程问题并且不建议使用),我可以改为腌制此:

best = chkpt_cb.best

这将存储回调已看到的最佳监视指标,它是一个浮动值,您可以在其下进行腌制并重新加载,然后执行以下操作:

chkpt_cb.best = best   # if chkpt_cb is a brand new object you create when colab killed your session. 

这是我自己的设置:

# All paths should be on Google Drive, I omitted it here for simplicity.

chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    best = pickle.load(f)
    chkpt_cb.best = best

def save_chkpt_cb():
  with open('chkpt_cb.best.pickle', 'wb') as f:
    pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

history = model.fit_generator(generator=train_data_gen,
                          validation_data=dev_data_gen,
                          epochs=5,
                          callbacks=[chkpt_cb, save_chkpt_cb_callback])

因此,即使您的colab会话被终止,您仍然可以检索最新的最佳指标并向您的新实例通知该指标,并照常继续培训。当您重新编译有状态的优化器时,这特别有帮助,可能会导致损耗/度量值下降,并且不想在最初的几个时期保存这些模型。