Pytorch中缺少优化器参数

时间:2019-10-29 01:34:09

标签: pytorch

我有2个网络共享使用不同学习率的一个优化程序。简单代码如下所示:

optim = torch.optim.Adam([
{'params': A.parameters(), 'lr': args.A},
{'params': B.parameters(), 'lr': args.B}])

这是对的吗?我之所以这样问是因为,当我在优化器中检查参数时(使用下面的代码),我发现只有2个参数。

for p in optim.param_groups:
outputs = ''
for k, v in p.items():
    if k is 'params':
        outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
    else:
        outputs += (k + ': ' + str(v).ljust(10) + ' ')
print(outputs)

仅打印2个参数:

params: torch.Size([16, 1, 80])        lr: 1e-05      betas: (0.9, 0.999) eps: 1e-08      weight_decay: 0          amsgrad: False

params: torch.Size([30, 10])           lr: 1e-05      betas: (0.9, 0.999) eps: 1e-08      weight_decay: 0          amsgrad: False

实际上,2个网络具有100多个参数。我认为所有参数都将被打印出来。为什么会这样呢?谢谢!

1 个答案:

答案 0 :(得分:1)

您只打印每个参数组的第一个张量:

if k is 'params':
    outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')  # only v[0] is printed!

尝试并打印所有参数:

if k is 'params':
    outputs += (k + ': ')
    for vp in v:
        outputs += (str(vp.shape).ljust(30) + ' ')