pytorch:记住几层序列的输出

时间:2017-08-28 12:41:20

标签: python pytorch

我在pytorch中有一个模型,并且从一个前向传递想要提取几个层的输出。这可能吗?

e.g。 vgg

的前五个转换层的输出
import torchvision
vgg = torchvision.models.vgg19_bn()

1 个答案:

答案 0 :(得分:0)

我认为最简单的方法是定义一个新模型,其中定义了所有感兴趣的层:

from torch import nn
import torchvision

class VGG(nn.Module):

    def __init__(self):

        super(VGG, self).__init__()

        vgg = torchvision.models.vgg19_bn()

        self.l_00 = list(vgg.features.children())[0]
        self.l_01 = list(vgg.features.children())[1]
        self.l_02 = list(vgg.features.children())[2]
        self.l_03 = list(vgg.features.children())[3]
        self.l_04 = list(vgg.features.children())[4]
        self.l_05 = list(vgg.features.children())[5]
        self.l_06 = list(vgg.features.children())[6]
        self.l_07 = list(vgg.features.children())[7]
        self.l_08 = list(vgg.features.children())[8]
        self.l_09 = list(vgg.features.children())[9]
        self.l_10 = list(vgg.features.children())[10]
        self.l_11 = list(vgg.features.children())[11]
        self.l_12 = list(vgg.features.children())[12]
        self.l_13 = list(vgg.features.children())[13]
        self.l_14 = list(vgg.features.children())[14]
        self.l_15 = list(vgg.features.children())[15]
        self.l_16 = list(vgg.features.children())[16]

    def forward(self, x):

        x  = self.l_00(x)
        x  = self.l_01(x)
        c1 = self.l_02(x)
        x  = self.l_03(x)
        x  = self.l_04(x)
        c2 = self.l_05(x)
        x  = self.l_06(c2)
        x  = self.l_07(x)
        x  = self.l_08(x)
        c3 = self.l_09(x)
        x  = self.l_10(c3)
        x  = self.l_11(x)
        c4 = self.l_12(x)
        x  = self.l_13(c4)
        x  = self.l_14(x)
        x  = self.l_15(x)
        c5  = self.l_16(x)

        return c1, c2, c3, c4, c5