pytorch获取模型的所有层

时间:2019-02-23 22:31:13

标签: python pytorch

获取pytorch模型并获取没有任何nn.Sequence分组的所有层的列表的最简单方法是什么?例如,更好的方法吗?

import pretrainedmodels
model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')

l = []
def unwrap_model(model):
    for i in children(model):
        if isinstance(i, nn.Sequential): unwrap_model(i)
        else: l.append(i)
unwrap_model(model)            

print(l)

3 个答案:

答案 0 :(得分:5)

您可以使用modules()方法遍历模型的所有模块。它也位于每个Sequantial内部。

l = [module for module in model.modules() if type(module) != nn.Sequential]

这是一个简单的例子:

model = nn.Sequential(nn.Linear(2, 2), 
                      nn.ReLU(),
                      nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()))

输出:

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]

答案 1 :(得分:1)

我这样做是这样的:

def flatten(el):
    flattened = [flatten(children) for children in el.children()]
    res = [el]
    for c in flattened:
        res += c
    return res

cnn = nn.Sequential(Custom_block_1, Custom_block_2)
layers = flatten(cnn)

答案 2 :(得分:0)

我将其划分为一个更深层的模型,并非所有块都来自nn.sequential。

def get_children(model: torch.nn.Module):
    # get children form model!
    children = list(model.children())
    flatt_children = []
    if children == []:
        # if model has no children; model is last child! :O
        return model
    else:
       # look for children from children... to the last child!
       for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                flatt_children.append(get_children(child))
    return flatt_children