由于它不是“检查点”,因此恢复TensorFlow 2.0 Training的崩溃恢复的标准方法是什么?

时间:2019-12-28 22:10:51

标签: tensorflow2.0 checkpointing

要在崩溃后恢复训练,不仅必须还原模型,还必须还原进入model.fit(...)进程状态的所有对象和参数。

在我费力分叉keras代码以实现fitting对象时,例如,包括训练数据,我想知道崩溃的标准方法(如果有) -recovery恢复从上次停止的TensorFlow 2.0培训。

还是有人真的在TensorFlow对象模型中填补了这个明显的空白?

1 个答案:

答案 0 :(得分:1)

检查tf.keras.Model.fit()进程的规范方法是ModelCheckpoint回调。

用法类似于:

mode.fit(..., callbacks=[tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)]

默认情况下,在每个训练时期结束时生成的已保存检查点不仅包括模型的架构和权重值,还包括训练状态。如果您有兴趣,可以研究其源代码here。保存的训练状态包括

  • 优化器配置
  • 优化器的权重变量值(对于状态优化器,例如Adam)
  • 损失和指标配置

这些内容是否涵盖您所考虑的所有训练状态?