为什么随着训练的进行保存 state_dict 会变慢?

时间:2021-05-20 12:41:15

标签: python pytorch

我正在按如下方式保存我的模型和优化器的状态字典:

if epoch % 50000 == 0:
  #checkpoint save every 50000 epochs
  print('\nSaving model... Loss is: ', loss)
  torch.save({
      'epoch': epoch,
      'model': self.state_dict(),
      'optimizer_state_dict': self.optimizer.state_dict(),
      'scheduler': self.scheduler.state_dict(),
      'loss': loss,
      'losses': self.losses,
      }, PATH)

当我第一次开始训练时,它可以在不到 5 秒的时间内完成训练。然而,经过几个小时的训练后,它需要两分钟多的时间来保存。我能想到的唯一原因是损失清单。但我看不出这会增加多少时间。

更新 1:
我的损失是:

self.losses = []

我将每个时期的损失附加到此列表中,如下所示:

    #... loss calculation
    loss.backward()
    self.optimizer.step()
    self.scheduler.step() 

    self.losses.append(loss)

1 个答案:

答案 0 :(得分:2)

如评论中所述,指令

self.losses.append(loss) 

绝对是罪魁祸首,应该换成

self.losses.append(loss.item())

原因是,当您存储张量 loss 时,您还同时存储了整个计算图(执行反向传播所需的所有信息)。换句话说,您不仅要存储张量,还要存储指向所有参与计算损失及其关系(添加、相乘等)的张量的指针。所以它会变得非常大非常快。

当您执行 loss.item()(或 loss.detach(),同样有效)时,您将张量从计算图中分离,因此您只存储您想要的内容:损失值本身,如一个简单的浮点值