我有 Model 和 Trainer pytorch-lightning 对象,它们初始化如下:
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join('experiments', experiment_name, '{epoch}-{avg_valid_iou:.4f}'),
save_top_k=1,
verbose=True,
monitor='avg_valid_iou',
mode='max',
prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback,
max_epochs=500,gpus=1,
logger=logger)
然后我开始使用:
trainer.fit(model)
但是训练被中断了,现在我想使用第 N 次迭代的检查点来恢复它 所以我尝试将模型和训练器初始化为:
model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt',
checkpoint_callback=checkpoint_callback,
max_epochs=500,gpus=1,
logger=logger)
但是一次又一次地从头开始训练(从第 0 个纪元开始,错误巨大)。我错过了什么?
答案 0 :(得分:0)
您不仅应该保存模型状态,还应该保存优化器状态和起始时期值。例如:
state({
'epoch': epoch + 1,
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
})
保存检查点后,您可以通过以下命令继续训练:
checkpoint = torch.load(state_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_val = checkpoint['epoch']
for epoch in range(start_val, max_val):
...
...