I am training a ResNet model to classify car brands.
I saved the weights at every epoch during training.
As a test, I stopped training at epoch 3.
import os
import tensorflow as tf

# checkpoint = ModelCheckpoint("best_model.hdf5", monitor='loss', verbose=1)
checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save weights, every epoch.
    save_freq='epoch')
model.save_weights(checkpoint_path.format(epoch=0))
history = model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)
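However, when I load them, I am not sure whether training resumes from the last saved epoch, because it shows Epoch 1/50 again. Below is the code I use to load the last saved model.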
from keras.models import Sequential, load_model
# load the model
new_model = load_model('./weights/cp-0003.ckpt')
# fit the model
history = new_model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)
This is what it looks like: [image: training resumed from the saved weights starts again at epoch 1/50]
Can anyone help?
Answer 0 (score: 0)
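You can use the initial_epoch parameter of fit_generator. By default it is set to 0, but you can set it to any non-negative integer, here the number of epochs already completed: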
import os
import tensorflow as tf
from tensorflow.keras.models import load_model

# model, training_set and test_set are assumed to be defined as in the question.
checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save weights, every epoch.
    save_freq='epoch')
model.save_weights(checkpoint_path.format(epoch=0))
# First run: train for 3 epochs, writing a checkpoint after each one.
history = model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=3,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback]
)
# Resume: load the checkpoint written after epoch 3.
new_model = load_model('./weights/cp-0003.ckpt')
# fit the model, continuing at epoch 4 of 50
history = new_model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=50,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback],
    initial_epoch=3
)
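This trains your model for 50 - 3 = 47 additional epochs, because epochs is the index of the final epoch rather than the number of epochs to run.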
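Instead of hard-coding initial_epoch=3, the epoch number can be read back from the checkpoint name. The following is only a minimal sketch, assuming the checkpoints follow the cp-{epoch:04d}.ckpt naming pattern from the code above and are saved as whole models. The helper names epoch_from_checkpoint and latest_checkpoint are hypothetical and not part of Keras.

```python
import os
import re

# Matches names produced by "cp-{epoch:04d}.ckpt" (pattern assumed from the code above).
_CKPT_RE = re.compile(r"cp-(\d+)\.ckpt$")


def epoch_from_checkpoint(path):
    """Return the epoch number encoded in a checkpoint path, or None if absent."""
    name = os.path.basename(os.path.normpath(path))
    match = _CKPT_RE.search(name)
    return int(match.group(1)) if match else None


def latest_checkpoint(names):
    """Return (name, epoch) for the highest-numbered checkpoint, or None."""
    found = [(name, epoch_from_checkpoint(name)) for name in names]
    found = [(name, epoch) for name, epoch in found if epoch is not None]
    return max(found, key=lambda item: item[1]) if found else None
```

With these helpers, something like latest_checkpoint(os.listdir("weights")) would give both the name to pass to load_model and the value to pass as initial_epoch.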
If you are using TensorFlow 2.x, a few notes about your code:
fit_generator is deprecated, because fit now supports generators.
from keras.... should be replaced with from tensorflow.keras...
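Following the notes above, resuming with fit instead of fit_generator might look like the sketch below. It assumes model is a tf.keras model (for example one returned by tensorflow.keras.models.load_model) and that training_set and test_set are the generators from the question. The wrapper name resume_training and its parameters are hypothetical; only fit and its epochs, initial_epoch, validation_data and callbacks arguments come from Keras.

```python
def resume_training(model, training_set, test_set, completed_epochs,
                    total_epochs=50, callbacks=None):
    """Continue training with Model.fit after `completed_epochs` epochs.

    Hypothetical helper: `model` is assumed to be a tf.keras model and the
    two data arguments are assumed to be generators or Sequence objects.
    """
    if not 0 <= completed_epochs <= total_epochs:
        raise ValueError("completed_epochs must lie between 0 and total_epochs")
    return model.fit(
        training_set,
        validation_data=test_set,
        epochs=total_epochs,             # index of the final epoch, not a count of extra epochs
        initial_epoch=completed_epochs,  # progress display resumes at completed_epochs + 1
        callbacks=callbacks,
    )
```

For the case in the question, the call would be along the lines of resume_training(new_model, training_set, test_set, completed_epochs=3, callbacks=[cp_callback]), and the progress output should then start at Epoch 4/50.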