火炬,训练中学习率的变化

时间:2020-06-25 12:21:13

标签: pytorch


x=np.linspace(0,20,100)


g=1+0.2*np.exp(-0.1*(x-7)**2)
y=np.sin(g*x)

plt.plot(x,y)

plt.show()


x=torch.from_numpy(x)

y=torch.from_numpy(y)

x=x.reshape((100,1))
y=y.reshape((100,1))

MM=nn.Sequential()
MM.add_module('L1',nn.Linear(1,128))
MM.add_module('R1',nn.ReLU())
MM.add_module('L2',nn.Linear(128,128))
MM.add_module('R2',nn.ReLU())
MM.add_module('L3',nn.Linear(128,128))
MM.add_module('R3',nn.ReLU())
MM.add_module('L4',nn.Linear(128,128))
MM.add_module('R5',nn.ReLU())
MM.add_module('L5',nn.Linear(128,1))
MM.double()
L=nn.MSELoss()

lr=3e-05           ######
opt=torch.optim.Adam(MM.parameters(),lr)     #########
Epo=[]
COST=[]

for epoch in range(8000):

  opt.zero_grad()
  err=L(torch.sin(MM(x)),y)
  Epo.append(epoch)
  COST.append(err)
  err.backward()
  if epoch%100==0:
    print(err)
  opt.step()


Epo=np.array(Epo)/1000.
COST=np.array(COST)
pred=torch.sin(MM(x)).detach().numpy()
Trans=MM(x).detach().numpy()
x=x.reshape((100))
pred=pred.reshape((100))
Trans=Trans.reshape((100))

fig = plt.figure(figsize=(10,10))
#ax = fig.gca(projection='3d')
ax = fig.add_subplot(2,2,1)
surf = ax.plot(x,y,'r')
    
    #ax.plot_surface(x_dat,y_dat,z_pred)
    #ax.plot_wireframe(x_dat,y_dat,z_pred,linewidth=0.1)
fig.tight_layout()
    #plt.show()
ax = fig.add_subplot(2,2,2)
surf = ax.plot(x,pred,'g')
fig.tight_layout()

ax = fig.add_subplot(2,2,3)
surff=ax.plot(Epo,COST,'y+')
plt.ylim(0,1100)

ax = fig.add_subplot(2,2,4)
surf = ax.plot(x,Trans,'b')
fig.tight_layout()

plt.show()

这是原始代码1。 为了在训练过程中改变学习速度,我尝试将“ opt”的位置移动为

Epo=[]
COST=[]

for epoch in range(8000):
  lr=3e-05           ######
  opt=torch.optim.Adam(MM.parameters(),lr)     #########
  opt.zero_grad()
  err=L(torch.sin(MM(x)),y)
  Epo.append(epoch)
  COST.append(err)
  err.backward()
  if epoch%100==0:
    print(err)
  opt.step()

这是代码2。 代码2也可以运行,但是结果与代码1完全不同。

有什么区别,并且要在训练过程中改变学习率(如lr =(1-epoch / 10000 * 0.99),该怎么办?

1 个答案:

答案 0 :(得分:1)

您不应该将优化器定义移入训练循环,因为优化器会保留许多与训练历史相关的信息,例如,在Adam的情况下,梯度的运行平均值会在优化器的内部机制中动态存储和更新,... 因此,每次迭代都安装一个新的优化程序会使您失去此历史记录跟踪。

要动态更新学习速率,在pytorch中提出了许多调度程序类(指数衰减,循环衰减,余弦退火等)。您可以从文档中查看它们以获取调度程序的完整列表,也可以根据需要实施自己的调度程序:https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

文档中的示例:通过将学习率乘以0.5(每10个历元)来衰减学习率,您可以如下使用StepLR调度程序:

opt = torch.optim.Adam(MM.parameters(), lr)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5) 

在原始代码1中,您可以执行以下操作:

for epoch in range(8000):
  opt.zero_grad()
  err=L(torch.sin(MM(x)),y)
  Epo.append(epoch)
  COST.append(err)
  err.backward()
  if epoch%100==0:
    print(err)
  opt.step()
  scheduler.step()

正如我所说的,您还有许多其他类型的lr调度程序,因此您可以从文档中选择或实施自己的