我想在pytorch-lightning中实现下面的训练循环(将其作为伪代码读取)。唯一的特点是不是对每个批次都执行向后和优化步骤。
(背景:我正在尝试实施几次学习算法;尽管我需要在每一步都做出预测-forward
方法
-我需要随机执行渐变更新-if
-块。
for batch in batches:
x, y = batch
loss = forward(x,y)
optimizer.zero_grad()
if np.random.rand() > 0.5:
loss.backward()
optimizer.step()
我提出的解决方案需要实现backward
和optimizer_step
方法,如下所示:
def backward(self, use_amp, loss, optimizer):
self.compute_grads = False
if np.random.rand() > 0.5:
loss.backward()
nn.utils.clip_grad_value_(self.enc.parameters(), 1)
nn.utils.clip_grad_value_(self.dec.parameters(), 1)
self.compute_grads = True
return
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
if self.compute_grads:
optimizer.step()
optimizer.zero_grad()
return
注意:通过这种方式,我需要在类级别存储compute_grads
属性。
在pytorch-lightning中实现它的“最佳实践”方法是什么?有没有更好的方法使用钩子?
答案 0 :(得分:1)
这是一个好方法!这就是钩子的作用。
有一个新的回调模块也可能会有所帮助: https://pytorch-lightning.readthedocs.io/en/0.7.1/callbacks.html