How can I see the adapted learning rates of PyTorch's Adam optimizer?

Date: 2020-05-13 11:14:08

Tags: pytorch

Many optimizers implement adaptive learning rate methods. Is it possible to see the adapted values that Adam derives from the initial learning rate?

Here is a similar question about Adadelta, where the answer is to look up the ["acc_delta"] key in the optimizer state, but Adam's state has no such key.
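A quick check (a minimal sketch, not from the linked question) shows which keys Adam actually keeps in its per-parameter state:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer populates its state.
loss = model(torch.randn(4, 1)).pow(2).mean()
loss.backward()
optimizer.step()

p = optimizer.param_groups[0]['params'][0]
keys = sorted(optimizer.state[p].keys())
# Adam stores 'exp_avg', 'exp_avg_sq' and 'step' -- no 'acc_delta'.
print(keys)
```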

1 Answer:

Answer 0: (score: 1)

AFAIK there is no super-simple way to do this. However, you can recompute the current learning rate for a given parameter using the Adam implementation in PyTorch: https://pytorch.org/docs/stable/_modules/torch/optim/adam.html

I came up with this minimal working example:

import torch
import torch.nn as nn
import torch.optim as optim

def get_current_lr(optimizer, group_idx, parameter_idx):
    # Adam has different learning rates for each parameter, so we need to
    # pick the group and parameter first.
    group = optimizer.param_groups[group_idx]
    p = group['params'][parameter_idx]

    beta1, _ = group['betas']
    state = optimizer.state[p]

    # int() handles both the plain-int 'step' of older PyTorch versions
    # and the 0-dim tensor used in newer ones.
    bias_correction1 = 1 - beta1 ** int(state['step'])
    current_lr = group['lr'] / bias_correction1
    return current_lr

x = torch.randn(100, 1)  # Just create a random tensor as input
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
niter = 20
for _ in range(niter):
    out = model(x)

    optimizer.zero_grad()
    loss = criterion(out, x)  # Here we learn the identity mapping
    loss.backward()
    optimizer.step()
    group_idx, param_idx = 0, 0
    current_lr = get_current_lr(optimizer, group_idx, param_idx)
    print('Current learning rate (g:%d, p:%d): %.4f | Loss: %.4f' % (group_idx, param_idx, current_lr, loss.item()))

It should output something like this:

Current learning rate (g:0, p:0): 0.0100 | Loss: 0.5181
Current learning rate (g:0, p:0): 0.0053 | Loss: 0.5161
Current learning rate (g:0, p:0): 0.0037 | Loss: 0.5141
Current learning rate (g:0, p:0): 0.0029 | Loss: 0.5121
Current learning rate (g:0, p:0): 0.0024 | Loss: 0.5102
Current learning rate (g:0, p:0): 0.0021 | Loss: 0.5082
Current learning rate (g:0, p:0): 0.0019 | Loss: 0.5062
Current learning rate (g:0, p:0): 0.0018 | Loss: 0.5042
Current learning rate (g:0, p:0): 0.0016 | Loss: 0.5023
Current learning rate (g:0, p:0): 0.0015 | Loss: 0.5003
Current learning rate (g:0, p:0): 0.0015 | Loss: 0.4984
Current learning rate (g:0, p:0): 0.0014 | Loss: 0.4964
Current learning rate (g:0, p:0): 0.0013 | Loss: 0.4945
Current learning rate (g:0, p:0): 0.0013 | Loss: 0.4925
Current learning rate (g:0, p:0): 0.0013 | Loss: 0.4906
Current learning rate (g:0, p:0): 0.0012 | Loss: 0.4887
Current learning rate (g:0, p:0): 0.0012 | Loss: 0.4868
Current learning rate (g:0, p:0): 0.0012 | Loss: 0.4848
Current learning rate (g:0, p:0): 0.0012 | Loss: 0.4829
Current learning rate (g:0, p:0): 0.0011 | Loss: 0.4810

Note that monitoring the learning rate of every single parameter may be neither feasible nor particularly helpful for large models.
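Also note that the helper above only accounts for the first-moment bias correction; Adam's actual step additionally divides by the bias-corrected second moment, so the effective rate differs per element of each parameter. A fuller per-element estimate could look like this (a sketch derived from the Adam update rule, not part of the original answer; `effective_lr` is a hypothetical helper name):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def effective_lr(optimizer, group_idx=0, parameter_idx=0):
    # Sketch of the per-element step size Adam applies, from the update
    #   p -= (lr / bc1) * exp_avg / (sqrt(exp_avg_sq / bc2) + eps)
    # where bc1 and bc2 are the bias-correction terms.
    group = optimizer.param_groups[group_idx]
    p = group['params'][parameter_idx]
    state = optimizer.state[p]

    beta1, beta2 = group['betas']
    step = int(state['step'])  # 'step' may be a 0-dim tensor in newer PyTorch
    bc1 = 1 - beta1 ** step
    bc2 = 1 - beta2 ** step

    denom = (state['exp_avg_sq'] / bc2).sqrt() + group['eps']
    return (group['lr'] / bc1) / denom  # tensor, same shape as the parameter

model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 1)).pow(2).mean()
loss.backward()
optimizer.step()
print(effective_lr(optimizer))
```

This returns a tensor of the same shape as the parameter, which makes the closing caveat concrete: for a large model, there is one such rate per weight.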