如何在pytorch-lightning中使用lbfgs优化器?

时间:2019-09-26 11:26:24

标签: pytorch

我在将pytorch的LBFGS优化器与闪电配合使用时遇到问题。 我使用here中的模板开始一个新项目,这是我尝试过的代码(仅培训部分):

def training_step(self, batch, batch_nb):
    x, y = batch
    x = x.float()
    y = y.float()
    y_hat = self.forward(x)
    return {'loss': F.mse_loss(y_hat, y)}

def configure_optimizers(self):
    optimizer = torch.optim.LBFGS(self.parameters())
    return optimizer

def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
    def closure():
        optimizer.zero_grad()
        l = self.training_step(batch, batch_nb)
        loss = l['loss']
        loss.backward()
        return loss

    optimizer.step(closure)

pytorch的LBFGS优化器需要一个闭包函数(请参见herehere),但我不知道如何在模板中定义它,特别是我不知道该批处理如何数据传递给优化器。我试图定义一个自定义的optimizer_step函数,但是在将批处理传递到闭包函数中时遇到一些问题。

对于能帮助我解决此问题或为我指明正确方向的任何建议,我将非常感谢。

环境:
  • PyTorch版本:1.2.0 + cpu
  • 闪电版本:0.4.9
  • 试管版本:0.7.1

1 个答案:

答案 0 :(得分:0)

#310中添加了对lbfgs优化器的支持,现在不需要定义闭包函数。