Unable to use backward hooks in a loop

Date: 2019-12-29 02:58:05

Tags: python-3.x pytorch

I am trying to implement The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks in PyTorch, using backward hooks to stop the incoming gradients and thereby freeze certain weights. Everything seems to work fine when I add the hooks manually, but when I try to do the same thing in a loop, it throws an error. mask_dict is a dictionary with layer names as keys and boolean masks as values. Each mask is created by setting the bottom pr_value fraction of the weights (by absolute value, i.e. L1 norm) to zero. Here is the code:

import torch

def l1norm_mask(model, pr_value, prune_large=False):
    """Build a boolean mask per linear layer, zeroing the bottom
    pr_value fraction of weights by absolute value."""
    mask_dict = {}
    for name, module in model.named_modules():
        if 'linear' in name:
            t = module.weight
            # indices of the smallest |w| (or largest, if prune_large=True)
            bottomk = torch.topk(torch.abs(t).view(-1),
                                 k=round(pr_value * t.nelement()),
                                 largest=prune_large)
            clone = t.clone().detach()
            clone.view(-1)[bottomk.indices] = 0
            mask_dict[name] = clone.bool()
    return mask_dict


def apply_mask(model, mask_dict):
    for name, module in model.named_modules():
        if 'linear' in name:
            print('')
            module.weight.data *= mask_dict[name]
            print('module name is:', name, 'and weight size is:', module.weight.size())
            print('corresponding tensor is:', mask_dict[name].shape)
            module.weight.register_hook(lambda x: x*mask_dict[name])  #<-- problem here

#     model.linear1.weight.register_hook(lambda x: x*mask_dict['linear1'])
#     model.linear2.weight.register_hook(lambda x: x*mask_dict['linear2'])
#     model.linear3.weight.register_hook(lambda x: x*mask_dict['linear3'])

mask = l1norm_mask(model, 0.60)
apply_mask(model, mask)
mask.keys()

Printing some tensor shapes suggests that all of them line up for element-wise multiplication. The output is:

module name is: linear1 and weight size is: torch.Size([300, 784])
corresponding tensor is: torch.Size([300, 784])

module name is: linear2 and weight size is: torch.Size([100, 300])
corresponding tensor is: torch.Size([100, 300])

module name is: linear3 and weight size is: torch.Size([10, 100])
corresponding tensor is: torch.Size([10, 100])

dict_keys(['linear1', 'linear2', 'linear3'])

Trying to train the network gives the following error:

--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-2e752da442fd> in <module>
----> 1 losses = [update(x,y,3e-3) for x,y in data.train_dl]

<ipython-input-18-2e752da442fd> in <listcomp>(.0)
----> 1 losses = [update(x,y,3e-3) for x,y in data.train_dl]

<ipython-input-17-ca0fbe546149> in update(x, y, lr)
      3     y_hat = model(x)
      4     loss = loss_fn(y_hat, y)
----> 5     loss.backward()
      6     opt.step()
      7     opt.zero_grad()

~/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    164                 products. Defaults to ``False``.
    165         """
--> 166         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    167 
    168     def register_hook(self, hook):

~/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
---> 99         allow_unreachable=True)  # allow_unreachable flag
    100 
    101 

<ipython-input-16-46b511a78892> in <lambda>(x)
     35             print('module name is:', name, 'and weight size is:', module.weight.size())
     36             print('correspoding tensor is:', mask_dict[name].shape)
---> 37             module.weight.register_hook(lambda x: x*mask_dict[name])
     38 
     39 #     model.linear1.weight.register_hook(lambda x: x*mask_dict['linear1'])

RuntimeError: The size of tensor a (300) must match the size of tensor b (100) at non-singleton dimension 1

Update: I modified the code slightly to print the shape of the incoming gradient and of the mask applied inside the hook. It appears that whenever the backward hooks are called, they all pick up the same value of name (the last one assigned in the loop), so the same mask is passed to every gradient, causing the matrix-size mismatch error. This is the output I get:

shape of grad torch.Size([10, 100])  mask-shape torch.Size([10, 100])

shape of grad torch.Size([100, 300])  mask-shape torch.Size([10, 100])

shape of grad torch.Size([300, 784])  mask-shape torch.Size([10, 100])

shape of grad torch.Size([10, 100])  mask-shape torch.Size([10, 100])

shape of grad torch.Size([100, 300])  mask-shape torch.Size([10, 100])

shape of grad torch.Size([300, 784])  mask-shape torch.Size([10, 100])

This repetition continues for …
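The behaviour described above looks like Python's late binding of closures rather than a PyTorch issue: a lambda registered in a loop captures the variable name itself, not its value at the iteration where the hook was registered, so every hook reads whatever name last held. A minimal sketch of the effect, and of the usual fix of binding the current value through a default argument (standard library only, no PyTorch needed to reproduce it):

```python
names = ['linear1', 'linear2', 'linear3']

# Late binding: each lambda captures the *variable* `name`, so once the
# loop has finished, all three closures see its final value.
late = [lambda: name for name in names]
print([f() for f in late])    # ['linear3', 'linear3', 'linear3']

# Fix: a default argument is evaluated at definition time, freezing the
# per-iteration value inside each closure.
early = [lambda n=name: n for name in names]
print([f() for f in early])   # ['linear1', 'linear2', 'linear3']
```

Applied to the hook registration above, the same trick would read module.weight.register_hook(lambda grad, name=name: grad * mask_dict[name]), so each hook keeps the layer name it was registered with.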

0 Answers:

No answers