Keras-没有停止和继续训练的好方法吗?

时间:2020-09-07 10:39:32

标签: python tensorflow keras tensorflow2.0 tf.keras

经过大量研究,似乎没有一种很好的方法来正确地停止和恢复使用Tensorflow 2 / Keras模型的训练。无论您是使用model.fit() 还是使用自定义训练循环的,都是如此。

训练时似乎有2种支持的方法来保存模型:

  1. 使用model.save_weights()save_weights_only=Truetf.keras.callbacks.ModelCheckpoint,仅保存模型的权重。在我所见过的大多数示例中,这似乎都是首选方法,但是它存在许多主要问题:

    • 优化器状态未保存,这意味着恢复训练将不正确。
    • 学习率时间表已重置-对于某些型号来说可能是灾难性的。
    • Tensorboard日志返回到第0步-除非实施了复杂的解决方法,否则基本上没有任何日志记录。
  2. 使用model.save()save_weights_only=False保存整个模型,优化器等。优化器状态已保存(良好),但仍然存在以下问题:

    • Tensorboard日志仍然返回到步骤0
    • 学习率计划仍在重置(!!!)
    • 无法使用自定义指标。
    • 当使用自定义训练循环时,这根本不起作用-自定义训练循环使用非编译模型,并且似乎不支持保存/加载非编译模型。

我发现最好的解决方法是使用自定义训练循环,手动保存步骤。这可以修复张量板日志记录,并且可以通过执行类似keras.backend.set_value(model.optimizer.iterations, step)的操作来固定学习速率计划。但是,由于无法进行完整的模型保存,因此不会保留优化器状态。我看不到至少独立地保存优化器状态的方法,至少没有很多工作。搞乱了我的LR日程安排也感到很混乱。

我错过了什么吗?那里的人们如何使用此API进行保存/恢复?

2 个答案:

答案 0 :(得分:3)

您是对的,没有内置的可恢复性支持-这正是促使我创建DeepTrain的原因。就像TensorFlow / Keras的Pytorch Lightning(在不同方面更胜一筹)。

为什么要使用另一个库?我们还不够吗?如果有的话,我不会建造它。 DeepTrain针对“保姆法”进行了量身定制的培训:训练较少的模型,但进行彻底的训练。密切监视每个阶段,以诊断出问题所在和解决方法。

灵感来自我自己的用途;我会在一个较长的时期内看到“验证峰值”,并且无法停下来,因为它会重新启动该时期或破坏火车循环。忘了知道自己适合哪一批,或者剩下多少。

与Pytorch Lightning相比如何?凭借独特的火车调试实用程序,具有出色的可重复性和自省性-但Lightning在其他方面表现更好。我在工作中有一个完整的列表比较,将在一周内发布。

Pytorch支持即将到来?如果我说服闪电开发团队弥补相对于DeepTrain的缺点,那就不要-否则可能。同时,您可以浏览Examples的画廊。


最小示例

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()

您可以随时KeyboardInterrupt,检查模型,训练状态,数据生成器-并恢复。

答案 1 :(得分:2)

已为 tf.keras.callbacks.experimental.BackupAndRestore 添加了用于从中断中恢复训练的

tensorflow>=2.3 API。根据我的经验,它非常有效。

参考: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore