我正在按如下方式保存我的模型和优化器的状态字典:
if epoch % 50000 == 0:
#checkpoint save every 50000 epochs
print('\nSaving model... Loss is: ', loss)
torch.save({
'epoch': epoch,
'model': self.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
'loss': loss,
'losses': self.losses,
}, PATH)
当我第一次开始训练时,它可以在不到 5 秒的时间内完成训练。然而,经过几个小时的训练后,它需要两分钟多的时间来保存。我能想到的唯一原因是损失清单。但我看不出这会增加多少时间。
更新 1:
我的损失是:
self.losses = []
我将每个时期的损失附加到此列表中,如下所示:
#... loss calculation
loss.backward()
self.optimizer.step()
self.scheduler.step()
self.losses.append(loss)
答案 0 :(得分:2)
如评论中所述,指令
self.losses.append(loss)
绝对是罪魁祸首,应该换成
self.losses.append(loss.item())
原因是,当您存储张量 loss
时,您还同时存储了整个计算图(执行反向传播所需的所有信息)。换句话说,您不仅要存储张量,还要存储指向所有参与计算损失及其关系(添加、相乘等)的张量的指针。所以它会变得非常大非常快。
当您执行 loss.item()
(或 loss.detach()
,同样有效)时,您将张量从计算图中分离,因此您只存储您想要的内容:损失值本身,如一个简单的浮点值