保存模型,加载模型并继续培训的最佳方法是什么?

时间:2019-10-22 10:15:18

标签: python-3.x keras tensorflow2.0

我在这里经历了许多答案,但发现它们部分正确或无用。

我有一个需要训练的keras / tensorflow模型。在这次培训中,我的模型将是

  • 训练有素的时代
  • 在每个时期保存
  • 加载其他任何时间以继续培训

我应该怎么做?

2 个答案:

答案 0 :(得分:0)

根据需要,可以使用model.save()或使用ModelCheckPoint callback

答案 1 :(得分:0)

我找到了一种简单的方法。尽管此方法最初是由其他用户提出的,但对于模型训练的重置以及其加载的模型训练时验证准确性一直非常低(并且不是最后一天结束的训练),始终存在寻求者的抱怨。 。可以解决这个问题。我想用一个例子来说明这一点:

假设我的模型定义为:

def get_model_classif_nasnet():
    inputs = Input((224, 224, 3))
    #Other layers not shown here...
    model = Model(inputs, out)
    model.compile(optimizer=Adam(0.0001), loss=binary_crossentropy, metrics=['acc'])
    model.summary()
    return model

在每个时期之后,我们将保存模型进度。为此,我们使用检查点。如果保存的不同检查点具有相关名称,我们也将感到高兴。(例如,该名称应详细说明所经历的培训环境)

h5_path = "weights-improvement-{epoch:02d}-{val_loss:.4f}-{val_acc:.2f}.h5"

checkpoint = ModelCheckpoint(h5_path,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True,
                             mode='max'
                            )

现在,让我们使用以上知识来保存模型-

1)训练并保存模型

2)加载

3)继续培训

1)

#Initialize a model
old_model = get_model_classif_nasnet()

#Let's train it
 batch_size = 32

 history = old_model.fit_generator(
    #Training and Validation data...
    epochs=2, verbose=1,
    callbacks=[checkpoint],
    #Some other parameters (not necessarily present in your method)
    steps_per_epoch = len(train) // batch_size,
    validation_steps=len(val) // batch_size
)

您的进度应如下所示: enter image description here

请注意,在第1个阶段之后保存了检查点。现在,让我们假设在第2个阶段的中间结束训练。因此,我们将一个检查点/模型图像另存为 .h5 文件

2)加载

#Again initialize a model
new_model = get_model_classif_nasnet()

3)继续培训

#There is nothing new here
batch_size = 32

 history = new_model.fit_generator(
    #Training and Validation data...
    epochs=8, verbose=1,
    callbacks=[checkpoint],
    #Some other parameters (not necessarily present in your method)
    steps_per_epoch = len(train) // batch_size,
    validation_steps=len(val) // batch_size
)

就是这样。最重要的是,即使完成所有这些操作,您仍然需要确保从一开始就将学习率保持在低水平 optimizer = Adam(0.0001)这是这里的关键。用户引用“正在发生,因为model.save(filename.h5)无法保存优化器的状态。”