如何在pytorch中应用变量的指数移动平均值衰减?

时间:2017-12-06 16:44:08

标签: deep-learning pytorch

我正在阅读以下文件。它使用EMA衰减变量 Bidirectional Attention Flow for Machine Comprehension

  

在训练期间,模型的所有权重的移动平均值为   保持指数衰减率为0.999。

他们使用TensorFlow,我找到了相关的EMA代码 https://github.com/allenai/bi-att-flow/blob/master/basic/model.py#L229

在PyTorch中,如何将EMA应用于变量?

1 个答案:

答案 0 :(得分:-2)

移动平均线是梯度下降动量的关键概念。

PyTorch document中,您可以找到:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

将参数momentum更改为您想要的值。