Pytorch:从批处理加载检查点而无需再次迭代数据集

时间:2021-04-13 09:56:24

标签: neural-network pytorch

我需要能够从批处理中加载,而不是从时代明智的检查点加载。我知道这不是最佳选择,但由于在我的培训中断之前我只有有限的培训时间(google colab 免费版),我需要能够从它停止的批次或该批次附近加载。

我也不想再次遍历所有数据,而是继续处理模型尚未看到的数据。

我目前行不通的方法:

def save_checkpoint(state, file=checkpoint_file):
  torch.save(state, file)

def load_checkpoint(checkpoint):
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  train_loss = checkpoint['train_loss']
  val_loss = checkpoint['val_loss']
  epoch = checkpoint['epoch']
  step = checkpoint['step']
  batch = checkpoint['batch']
  return model, optimizer, train_loss, val_loss, epoch, step, batch

虽然它确实从停止的地方加载了权重,但它会再次迭代所有数据。

另外,我什至需要捕获 train_lossval_loss 吗?无论是否包含它们,我都看不到输出损失的差异。因此,我假设它已经包含在 model.load_state_dict (?)

我假设捕获 step 和 batch 不会以这种方式工作,我实际上需要在我的 class DataSet 中包含某种索引跟踪器?我已经在 DataSet 班级

中有这个
   def __getitem__(self, idx):
    question = self.data_qs[idx]
    answer1 = self.data_a1s[idx]
    answer2 = self.data_a2s[idx]
    target = self.targets[idx]

那么,这有用吗?

1 个答案:

答案 0 :(得分:1)

您可以通过创建具有属性 self.start_index=step*batch 的自定义数据集类来实现您的目标,并且在您的 __getitem__ 函数中,新索引应为 (self.start_index+idx)%len(self.data_qs) 如果您使用 shuffle=False 创建数据加载器,那么此技巧将起作用。

此外,您可以使用 shuffle=True 维护索引映射器并需要验证。