如何获取与PyTorch中的状态字典匹配的图层中的要素的值?

时间:2017-11-13 09:18:16

标签: python pytorch

我有一些cnn,我想从状态字典中获取对应于某个键的某个中间层的值。 怎么可以这样做? 感谢。

1 个答案:

答案 0 :(得分:1)

我认为您需要创建一个新类,重新定义给定模型的正向传递。但是,很可能需要创建有关模型架构的代码。你可以在这里找到一个例子:

class extract_layers():

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        module = self.model._modules[self.target_layer]

        # get output of the desired layer
        features = module(x)

        # get output of the whole model
        x = self.model(x)

        return x, features


model = models.vgg19(pretrained=True)
target_layer = 'features'
extractor = extract_layers(model, target_layer)

image = Variable(torch.randn(1, 3, 244, 244))
x, features = extractor(image)

在这种情况下,我使用的是pytorch models zoo中给出的预定义的vgg19网络。网络的层次结构分为两个模块:features用于卷积部分,classifier用于完全连接的部分。在这种情况下,由于features包裹了网络的所有卷积层,因此很简单。如果您的体系结构有多个具有不同名称的图层,则需要使用与此类似的内容存储其输出:

 for name, module in self.model._modules.items():
    x = module(x)  # forward the module individually
    if name in self.target_layer:
        features = x  # store the output of the desired layer

此外,您应该记住,您需要重新整形将卷积部分连接到完全连接的层的输出。如果您知道该图层的名称,应该很容易做到。