Below is the weight-drop code from the fastai implementation. It works fine, but when I make a very small change inside the `_setweights` function it breaks with `ValueError: can't optimize a non-leaf Tensor`.

Here is the code for the weight-dropped LSTM. What it does is randomly zero out elements of the weight matrix during training.
```python
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from fastcore.foundation import L
from fastai.torch_core import Module  # fastai's Module: no explicit super().__init__() needed

class WeightDropout(Module):
    "A module that wraps another layer in which some weights will be replaced by 0 during training."

    def __init__(self, module, weight_p, layer_names='weight_hh_l0'):
        self.module, self.weight_p, self.layer_names = module, weight_p, L(layer_names)
        for layer in self.layer_names:
            # Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            delattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            setattr(self.module, layer, F.dropout(w.data, p=self.weight_p, training=False))
            if isinstance(self.module, (nn.RNNBase, nn.modules.rnn.RNNBase)):
                self.module.flatten_parameters = self._do_nothing

    def _setweights(self):
        "Apply dropout to the raw weights."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            setattr(self.module, layer, F.dropout(raw_w.data, p=self.weight_p, training=self.training))

    def forward(self, *args):
        self._setweights()
        with warnings.catch_warnings():
            # To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

    def reset(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            setattr(self.module, layer, F.dropout(raw_w.data, p=self.weight_p, training=False))
        if hasattr(self.module, 'reset'): self.module.reset()

    def _do_nothing(self): pass
```
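For reference, here is a minimal usage sketch of the class above (the LSTM sizes and dropout probability are just illustrative):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20)
wd_lstm = WeightDropout(lstm, weight_p=0.5)  # drops weight_hh_l0 by default

x = torch.randn(5, 2, 10)                    # (seq_len, batch, input_size)
out, (h, c) = wd_lstm(x)
print(out.shape)                             # torch.Size([5, 2, 20])
```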
I am trying to implement variational dropout: variational dropout is dropout that uses the same dropout mask at every time step. PyTorch applies dropout in an entirely ad hoc, element-wise way (naive dropout), which is wrong and produces unstable results. With variational dropout we should instead zero out entire rows of the matrix (this is the important part).

This is extremely similar to the WeightDropout implementation above. The only change is that instead of dropping each element of the matrix with probability p, we drop each row with probability p.
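To make the distinction concrete, here is a small self-contained comparison of the two masking schemes (the shapes and seed are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.ones(4, 3)

# Naive dropout: each element is zeroed independently with probability p.
elementwise = F.dropout(w, p=0.5, training=True)

# Variational-style row dropout: one Bernoulli draw per row, shared across
# the whole row; survivors are scaled by 1/(1-p) in both cases.
row_mask = F.dropout(torch.ones(4, 1), p=0.5, training=True)
rowwise = w * row_mask  # broadcasting zeroes out entire rows

print(elementwise)
print(rowwise)
```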
So the only change I made is:
```python
def _setweights(self):
    "Apply dropout to the raw weights."
    for layer in self.layer_names:
        raw_w = getattr(self, f'{layer}_raw')
        # This is the modification I made to implement **variational dropout**:
        # draw one Bernoulli sample per row and broadcast it across the row.
        N, K = raw_w.shape
        mask = F.dropout(torch.ones(N, 1), p=self.weight_p, training=self.training)
        mask = mask.repeat(1, K)
        new = raw_w * mask
        setattr(self.module, layer, new)
```
As I said, this raises `ValueError: can't optimize a non-leaf Tensor`. Can someone help me from here?
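For context, the error itself can be reproduced in isolation: a PyTorch optimizer rejects any tensor that requires grad but is not a leaf of the autograd graph, and `raw_w * mask` is presumably reaching the optimizer as exactly such a tensor:

```python
import torch

raw_w = torch.nn.Parameter(torch.ones(4, 3))  # a leaf tensor: fine to optimize
mask = torch.ones(4, 1)
new = raw_w * mask                            # requires grad, but NOT a leaf

print(new.requires_grad, new.is_leaf)         # True False
torch.optim.SGD([new], lr=0.1)                # ValueError: can't optimize a non-leaf Tensor
```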