获取pytorch模型并获取没有任何nn.Sequence分组的所有层的列表的最简单方法是什么?例如,更好的方法吗?
import pretrainedmodels
model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
def unwrap_model(model):
for i in children(model):
if isinstance(i, nn.Sequential): unwrap_model(i)
else: l.append(i)
unwrap_model(model)
print(l)
答案 0 :(得分:5)
您可以使用modules()方法遍历模型的所有模块。它也位于每个Sequantial
内部。
l = [module for module in model.modules() if type(module) != nn.Sequential]
这是一个简单的例子:
model = nn.Sequential(nn.Linear(2, 2),
nn.ReLU(),
nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()))
输出:
[Linear(in_features=2, out_features=2, bias=True),
ReLU(),
Linear(in_features=2, out_features=1, bias=True),
Sigmoid()]
答案 1 :(得分:1)
我这样做是这样的:
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]
for c in flattened:
res += c
return res
cnn = nn.Sequential(Custom_block_1, Custom_block_2)
layers = flatten(cnn)
答案 2 :(得分:0)
我将其划分为一个更深层的模型,并非所有块都来自nn.sequential。
def get_children(model: torch.nn.Module):
# get children form model!
children = list(model.children())
flatt_children = []
if children == []:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
return flatt_children