PyTorch: how to temporarily save previous parameters and gradients for the next step/loop?

Asked: 2018-03-19 20:08:56

Tags: python deep-learning pytorch

I would like to know how to save the parameters/weights and their associated gradients (after calling backward() in each step/loop, via the optimizer object, e.g. optimizer = optim.SGD(...)) and then reload them into the optimizer at the next step. Since I need to save and reload the parameters and their gradients on every step/iteration, I do not want to write them out to a separate .pt or .pth file and load that file back in on every pass through the loop.
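To make that concrete, the kind of in-memory snapshot I have in mind looks roughly like this (just a sketch, assuming a model net; this is not my actual code):

    # Snapshot: clone each parameter tensor and its gradient into plain lists.
    saved_params = [p.data.clone() for p in net.parameters()]
    saved_grads = [None if p.grad is None else p.grad.data.clone()
                   for p in net.parameters()]

    # Restore: copy the saved values back in place, so the optimizer keeps
    # referencing the live model parameters.
    for p, sp, sg in zip(net.parameters(), saved_params, saved_grads):
        p.data.copy_(sp)
        if sg is not None:
            p.grad.data.copy_(sg)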

What I actually tried is the following:

    import copy  # needed for the deepcopy calls below

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

        def closure():
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            return loss.data[0]

        # I saved the parameters & their gradients here. I also modified the
        # return statement in sgd.py accordingly (not shown here).
        running_loss, param_groups_temp_prev, param_groups_grad_temp_prev = optimizer.step(closure)

        # Here I re-assign the two saved lists back to the optimizer for the
        # second call of optimizer.step(closure).
        optimizer.param_groups[0]['params'] = copy.deepcopy(param_groups_temp_prev)
        for ii in range(len(param_groups_grad_temp_prev)):
            pggt_prev = copy.deepcopy(param_groups_grad_temp_prev[ii])
            optimizer.param_groups[0]['params'][ii].grad = copy.deepcopy(Variable(pggt_prev))
        running_loss_temp, param_groups_temp1, param_groups_grad_temp1 = optimizer.step(closure)

I saved them as param_groups_temp_prev and param_groups_grad_temp_prev so I could reload them for the second call to optimizer.step(closure) within the same iteration. On the first iteration, everything went smoothly.

However, on the second pass through the loop, things were still fine right after calling optimizer.zero_grad(): all parameter gradients were cleared to zero. But after calling loss.backward(), the gradients were still zero, as if backward() had stopped working!

Can anyone help me fix this? Much appreciated!

1 Answer:

Answer 0 (score: 1)

I usually save everything with the following approach so that I can resume training later.

    import os
    import torch

    def save_checkpoint(state, filename='./checkpoint.pth.tar'):
        # Remove any stale checkpoint, then serialize the given state.
        if os.path.isfile(filename):
            os.remove(filename)
        torch.save(state, filename)

    save_checkpoint({
        'epoch': (epoch + 1),
        'state_dict': self.model.state_dict(),
        'best_acc': self.best_dev_acc,
        'optimizer': self.optimizer.state_dict(),
    }, self.config.save_path + 'model_best.pth.tar')

I load the state back as follows.

    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

Here, args.resume is the checkpoint file path.
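If writing a file on every iteration is too costly, the same state_dict mechanism can also be used purely in memory. A minimal sketch, assuming model and optimizer as above (note that state_dict() captures parameters and buffers but not the .grad fields, so gradients would still have to be cloned separately):

    import copy

    # Snapshot: deep-copy the state dicts so later training steps
    # do not mutate the saved copies.
    snapshot = {
        'model': copy.deepcopy(model.state_dict()),
        'optimizer': copy.deepcopy(optimizer.state_dict()),
    }

    # Restore: load_state_dict copies the saved values back into the
    # existing tensors, so the optimizer keeps referencing the live
    # model parameters.
    model.load_state_dict(snapshot['model'])
    optimizer.load_state_dict(snapshot['optimizer'])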