我正在努力在实施SGD的过程中实现势头。 据我了解,此更新看起来像这样:
parameters -= (lr * (p.grad*0.1 + p_delta_prev*0.9))
我的问题是我应该如何存储每次更新中以前的增量
这是我的更新功能:
#we now want to do the update with momentum
#momentum takes derivative, multiplies it by 0.1, then takes the previous update,
#multiplies it by 0.9 and we add the two together
#alpha = 0.1, beta = 0.9; p-=grad*0.1 + p*0.9
def update(x,y,lr):
wd = 1e-5
y_hat = model(x)
# weight decay
w2 = 0.
for p in model.parameters(): w2 += (p**2).sum()
# add to regular loss
loss = loss_func(y_hat, y) + w2*wd
loss.backward()
with torch.no_grad():
for p in model.parameters():
#p.grad is the slope of the line of that parameter
#current_p-previous_p to get difference
p_update = (lr * (p.grad*0.1 + p*0.9))
p.sub_(p_update)
p.grad.zero_()
return loss.item()
此处p*0.9
应该替换为p_delta_prev。但是我应该如何为每个参数存储这些增量?如果将它们保存到张量中,我将无法有效地将权重增量复制到内存中,从而使模型的大小变为原来的两倍。什么是完成此任务的好方法?我不想使用为我执行该激活的内置函数。我确实查看了pytorch sgd.py,它看起来像是存储状态。
我已更新代码:
#we now want to do the update with momentum
#momentum takes derivative, multiplys it by 0.1, then takes the previous update,
#multiplies it by 0.9 and we add the two together
#alpha = 0.1, beta = 0.9; p-=grad*0.1 + p*0.9
p_delta = {}
def update(x,y,lr):
wd = 1e-5
y_hat = model(x)
# weight decay
w2 = 0.
for p in model.parameters(): w2 += (p**2).sum()
# add to regular loss
loss = loss_func(y_hat, y) + w2*wd
loss.backward()
with torch.no_grad():
i = 0
for p in model.parameters():
#p.grad is the slope of the line of that parameter
if i not in p_delta:#check if key exists
p_delta[i] = torch.zeros_like(p)
p_update = (lr *p.grad) + (p_delta[i]*0.9)
p_delta[i] = p_update.clone()
p.sub_(p_update)
p.grad.zero_()
print((p_delta[i]))
i+=1
return loss.item()
我认为Excel电子表格中的代码不正确。杰里米(Jeremy)似乎显示:lr* ((p.grad*0.1) + (p_delta[i]*0.9))
,但许多教程似乎都显示:(lr *p.grad) + (p_delta[i]*0.9)
如果我们实施杰里米(Jeremy)的代码,则损失实际上要比普通GD慢。视频的一部分在这里:https://youtu.be/CJKnDu2dxOE?t=6581
答案 0 :(得分:1)
是的,它确实将参数momenta存储在由model.named_parameters()
返回的按其名称索引的字典中。我不知道如何严格地证明这一点,但是我坚信,如果不使用两倍于模型大小的额外内存,就不可能施加动量。
话虽如此,我不会担心,因为模型大小很少是影响整个算法内存的主要因素-为反向传播算法保留中间网络激活的成本要高得多。以VGG-16网络为例,它具有1.38亿个参数(数据来自here),如果以单精度存储,则略大于0.5gb。将此与合理的现代GPU上的6gb +进行比较。