我有2个网络共享使用不同学习率的一个优化程序。简单代码如下所示:
optim = torch.optim.Adam([
{'params': A.parameters(), 'lr': args.A},
{'params': B.parameters(), 'lr': args.B}])
这是对的吗?我之所以这样问是因为,当我在优化器中检查参数时(使用下面的代码),我发现只有2个参数。
for p in optim.param_groups:
outputs = ''
for k, v in p.items():
if k is 'params':
outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
else:
outputs += (k + ': ' + str(v).ljust(10) + ' ')
print(outputs)
仅打印2个参数:
params: torch.Size([16, 1, 80]) lr: 1e-05 betas: (0.9, 0.999) eps: 1e-08 weight_decay: 0 amsgrad: False
params: torch.Size([30, 10]) lr: 1e-05 betas: (0.9, 0.999) eps: 1e-08 weight_decay: 0 amsgrad: False
实际上,2个网络具有100多个参数。我认为所有参数都将被打印出来。为什么会这样呢?谢谢!
答案 0 :(得分:1)
您只打印每个参数组的第一个张量:
if k is 'params':
outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ') # only v[0] is printed!
尝试并打印所有参数:
if k is 'params':
outputs += (k + ': ')
for vp in v:
outputs += (str(vp.shape).ljust(30) + ' ')