我需要能够从批处理中加载,而不是从时代明智的检查点加载。我知道这不是最佳选择,但由于在我的培训中断之前我只有有限的培训时间(google colab 免费版),我需要能够从它停止的批次或该批次附近加载。
我也不想再次遍历所有数据,而是继续处理模型尚未看到的数据。
我目前行不通的方法:
def save_checkpoint(state, file=checkpoint_file):
torch.save(state, file)
def load_checkpoint(checkpoint):
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
train_loss = checkpoint['train_loss']
val_loss = checkpoint['val_loss']
epoch = checkpoint['epoch']
step = checkpoint['step']
batch = checkpoint['batch']
return model, optimizer, train_loss, val_loss, epoch, step, batch
虽然它确实从停止的地方加载了权重,但它会再次迭代所有数据。
另外,我什至需要捕获 train_loss
和 val_loss
吗?无论是否包含它们,我都看不到输出损失的差异。因此,我假设它已经包含在 model.load_state_dict
(?)
我假设捕获 step 和 batch 不会以这种方式工作,我实际上需要在我的 class DataSet
中包含某种索引跟踪器?我已经在 DataSet
班级
def __getitem__(self, idx):
question = self.data_qs[idx]
answer1 = self.data_a1s[idx]
answer2 = self.data_a2s[idx]
target = self.targets[idx]
那么,这有用吗?
答案 0 :(得分:1)
您可以通过创建具有属性 self.start_index=step*batch
的自定义数据集类来实现您的目标,并且在您的 __getitem__
函数中,新索引应为 (self.start_index+idx)%len(self.data_qs)
如果您使用 shuffle=False
创建数据加载器,那么此技巧将起作用。
此外,您可以使用 shuffle=True
维护索引映射器并需要验证。