从pytorch中的特定层获取输出

时间:2021-02-04 04:50:37

标签: python-3.x neural-network pytorch autoencoder

我在 Pytorch 中实现了一个自动编码器,并希望从指定的编码层中提取表示(输出)。这种设置类似于使用我们过去在 Keras 中拥有的子模型进行预测。

然而,在 Pytorch 中实现类似的东西看起来有点挑战。我尝试了 How to get the output from a specific layer from a PyTorch model?https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html 中所述的前向钩子,但无济于事。

你能帮我从特定层获取输出吗?

我在下面附上了我的代码:

class Autoencoder(torch.nn.Module):

    # Now defining the encoding and decoding layers.

    def __init__(self):
        super().__init__()   
        self.enc1 = torch.nn.Linear(in_features = 784, out_features = 256)
        self.enc2 = torch.nn.Linear(in_features = 256, out_features = 128)
        self.enc3 = torch.nn.Linear(in_features = 128, out_features = 64)
        self.enc4 = torch.nn.Linear(in_features = 64, out_features = 32)
        self.enc5 = torch.nn.Linear(in_features = 32, out_features = 16)
        self.dec1 = torch.nn.Linear(in_features = 16, out_features = 32)
        self.dec2 = torch.nn.Linear(in_features = 32, out_features = 64)
        self.dec3 = torch.nn.Linear(in_features = 64, out_features = 128)
        self.dec4 = torch.nn.Linear(in_features = 128, out_features = 256)
        self.dec5 = torch.nn.Linear(in_features = 256, out_features = 784)

    # Now defining the forward propagation step

    def forward(self,x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = F.relu(self.enc5(x))
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = F.relu(self.dec5(x))
    
        return x

autoencoder_network = Autoencoder()

我必须从标记为 enc1、enc2 ..、enc5 的编码器层中获取输出。

2 个答案:

答案 0 :(得分:1)

最简单的方法是显式返回您需要的激活:

    def forward(self,x):
        e1 = F.relu(self.enc1(x))
        e2 = F.relu(self.enc2(e1))
        e3 = F.relu(self.enc3(e2))
        e4 = F.relu(self.enc4(e3))
        e5 = F.relu(self.enc5(e4))
        x = F.relu(self.dec1(e5))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = F.relu(self.dec5(x))
    
        return x, e1, e2, e3, e4, e5

答案 1 :(得分:0)

您可以定义一个全局字典,例如 activations = {},然后在 forward 函数中为其赋值,例如 activations['enc1'] = x.clone().detach() 等。