pytorch-lightning中的自定义后退/优化步骤

时间:2019-12-13 17:10:27

标签: pytorch

我想在pytorch-lightning中实现下面的训练循环(将其作为伪代码读取)。唯一的特点是不是对每个批次都执行向后和优化步骤。

(背景:我正在尝试实施几次学习算法;尽管我需要在每一步都做出预测-forward方法 -我需要随机执行渐变更新-if-块。

for batch in batches:
    x, y = batch
    loss = forward(x,y)

    optimizer.zero_grad()

    if np.random.rand() > 0.5:
        loss.backward()
        optimizer.step()

我提出的解决方案需要实现backwardoptimizer_step方法,如下所示:

def backward(self, use_amp, loss, optimizer):
        self.compute_grads = False
        if np.random.rand() > 0.5:
            loss.backward()
            nn.utils.clip_grad_value_(self.enc.parameters(), 1)
            nn.utils.clip_grad_value_(self.dec.parameters(), 1)
            self.compute_grads = True
        return


    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
        if self.compute_grads:
            optimizer.step()
            optimizer.zero_grad()   
        return

注意:通过这种方式,我需要在类级别存储compute_grads属性。

在pytorch-lightning中实现它的“最佳实践”方法是什么?有没有更好的方法使用钩子?

1 个答案:

答案 0 :(得分:1)

这是一个好方法!这就是钩子的作用。

有一个新的回调模块也可能会有所帮助: https://pytorch-lightning.readthedocs.io/en/0.7.1/callbacks.html