在Pytorch中的两个图层(模块)之间共享权重的正确方法是什么?
根据我在Pytorch讨论论坛上的发现,有几种方法可以做到这一点。
例如,基于this discussion,我认为只需分配转置权重即可做到。这样做:
self.decoder[0].weight = self.encoder[0].weight.t()
但是,这被证明是错误的,并导致错误。
然后,我尝试将以上行包装在nn.Parameter()
中:
self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t())
这消除了错误,但是再次,这里没有共享发生。这样,我刚刚使用与encoder[0].weight.t()
相同的值初始化了一个 new 张量。
然后我发现了这个link,它提供了不同的权重分配方法。但是,我怀疑给出的所有方法是否正确。
例如,这样一种方式被演示:
# tied autoencoder using off the shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
def __init__(self, inp, out, weight):
super().__init__()
self.encoder = nn.Linear(inp, out, bias=False)
self.decoder = nn.Linear(out, inp, bias=False)
# tie the weights
self.encoder.weight.data = weight.clone()
self.decoder.weight.data = self.encoder.weight.data.transpose(0,1)
def forward(self, input):
encoded_feats = self.encoder(input)
reconstructed_output = self.decoder(encoded_feats)
return encoded_feats, reconstructed_output
基本上,它使用nn.Parameter()
创建一个新的权重张量,并将其分配给每个层/模块,如下所示:
weights = nn.Parameter(torch.randn_like(self.encoder[0].weight))
self.encoder[0].weight.data = weights.clone()
self.decoder[0].weight.data = self.encoder[0].weight.data.transpose(0, 1)
这真的让我感到困惑,这如何在这两层之间共享相同的变量?
不仅仅是克隆“原始” 数据吗?
当我使用这种方法并可视化权重时,我注意到可视化效果是不同的,这使我更加确定某些事情是不正确的。
我不确定不同的可视化是否仅是由于一个是另一个的转置,还是我刚刚已经怀疑过,它们是独立优化的(即,权重不在层之间共享)
答案 0 :(得分:1)
这可以通过PyTorch挂钩进行,您可以在其中更新A的前向挂钩以更改WB,还可以将WB冻结在M2 autograd中。
所以只需使用钩子即可。
from time import sleep
import torch
import torch.nn as nn
class M(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(1,2)
def forward(self, x):
x = self.l1(x)
return x
model = M()
model.train()
def printh(module, inp, outp):
sleep(1)
print("update other model parameter in here...")
h = model.register_forward_hook(printh)
for i in range(1,4):
x = torch.randn(1)
output = model(x)
h.remove()
答案 1 :(得分:1)
有趣的是,你的第一直觉是对的@Rika:
<块引用>这真的让我很困惑,这两个层之间如何共享相同的变量?不只是克隆“原始”数据吗?
实际上很多人在博客文章或他们自己的存储库中都弄错了。
还有
self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t())
将简单地创建一个新的权重矩阵,如您所写。
唯一可行的操作过程似乎是使用由 nn.Linear (torch.nn.functional.linear()
) 调用的线性函数:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Real off-the-shelf tied linear module
class TiedLinear(nn.Module):
def __init__(self, tied_to: nn.Linear, bias: bool = True):
super().__init__()
self.tied_to = tied_to
if bias:
self.bias = nn.Parameter(torch.Tensor(tied_to.in_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
# copied from nn.Linear
def reset_parameters(self):
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.tied_to.weight.t())
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.tied_to.weight.t(), self.bias)
# To keep module properties intuitive
@property
def weight(self) -> torch.Tensor:
return self.tied_to.weight.t()
# Shared weights, different biases
encoder = nn.Linear(in, out)
decoder = TiedLinear(encoder)
答案 2 :(得分:0)