Loading a saved model to resume training

Time: 2021-06-22 09:34:11

Tags: python tensorflow machine-learning keras resnet

I am training a ResNet model to classify car brands.

During training I saved a checkpoint at the end of every epoch.

As a test, I stopped training at epoch 3.

import os
import tensorflow as tf

# checkpoint = ModelCheckpoint("best_model.hdf5", monitor='loss', verbose=1)
checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save at the end of every epoch. save_weights_only defaults to False,
    # so this writes the full model, not just the weights.
    save_freq='epoch')

model.save_weights(checkpoint_path.format(epoch=0))

history = model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)

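However, when I load a checkpoint, I am not sure whether training resumes from the last saved epoch, because the output shows Epoch 1/50 again. Below is the code I use to load the last saved model.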

from keras.models import Sequential, load_model
# load the model
new_model = load_model('./weights/cp-0003.ckpt')

# fit the model
history = new_model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)

This is what it looks like: [image: training from the saved checkpoint starts from Epoch 1/50 again]

Can anyone help?

1 answer:

Answer 0 (score: 0)

You can use the initial_epoch parameter of fit_generator. It defaults to 0, but you can set it to the number of epochs that have already been completed:

import os
import tensorflow as tf
from tensorflow.keras.models import load_model

checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save at the end of every epoch. save_weights_only defaults to False,
    # so this writes the full model, which is what load_model expects.
    save_freq='epoch')

model.save_weights(checkpoint_path.format(epoch=0))

history = model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=3,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks = [cp_callback]
)


new_model = load_model('./weights/cp-0003.ckpt')

# fit the model
history = new_model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=50,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback],
    initial_epoch=3
)

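This trains your model for 50 - 3 = 47 additional epochs, because epochs is interpreted as the index of the final epoch, not as the number of epochs to run.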
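
To avoid hard-coding both the checkpoint path and initial_epoch=3, the epoch number can be read back from the checkpoint name. The following is only a sketch: the helper names epoch_from_name and latest_checkpoint are hypothetical, and it assumes checkpoints follow the cp-{epoch:04d}.ckpt naming used above and that each full-model checkpoint is an entry named exactly like that (side files such as cp-0000.ckpt.index written by save_weights are ignored).

```python
import re
from pathlib import Path

# Matches names produced by the "cp-{epoch:04d}.ckpt" pattern, e.g. "cp-0003.ckpt".
CKPT_NAME = re.compile(r"cp-(\d+)\.ckpt")


def epoch_from_name(name):
    """Return the epoch number encoded in a checkpoint name, or None."""
    match = CKPT_NAME.fullmatch(name)
    return int(match.group(1)) if match else None


def latest_checkpoint(checkpoint_dir):
    """Return (path, epoch) of the highest-numbered checkpoint, or (None, 0)."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None, 0
    found = []
    for entry in directory.iterdir():
        epoch = epoch_from_name(entry.name)
        if epoch is not None:
            found.append((epoch, entry))
    if not found:
        return None, 0
    epoch, path = max(found, key=lambda item: item[0])
    return path, epoch
```

With this, path, epoch = latest_checkpoint("weights") would give the path to pass to load_model and the value to pass as initial_epoch. If no checkpoint is found, epoch is 0 and training should start from a freshly built model.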


If you are using TensorFlow 2.x, a few notes about your code:

  • fit_generator is deprecated, because fit now supports generators
  • You should replace imports of the form from keras.... with from tensorflow.keras...
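
Putting both notes together, resuming with fit instead of fit_generator might look like the sketch below. The wrapper name resume_training is hypothetical; it assumes model is a tf.keras model (for example one returned by tensorflow.keras.models.load_model) and that training_set and test_set support len(), as the generators in the question do.

```python
def resume_training(model, training_set, test_set,
                    total_epochs, completed_epochs, callbacks=None):
    """Continue training with Model.fit, counting from completed_epochs."""
    return model.fit(
        training_set,
        validation_data=test_set,
        epochs=total_epochs,             # index of the final epoch, not a count of extra epochs
        initial_epoch=completed_epochs,  # epochs already finished before this call
        steps_per_epoch=len(training_set),
        validation_steps=len(test_set),
        callbacks=callbacks or [],
    )


# Example usage (assumes TensorFlow 2.x and the objects from the code above):
# from tensorflow.keras.models import load_model
# new_model = load_model("./weights/cp-0003.ckpt")
# history = resume_training(new_model, training_set, test_set, 50, 3, [cp_callback])
```

The progress output should then begin at Epoch 4/50 rather than Epoch 1/50.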