在火炬视觉模型中找到所有ReLU层

时间:2018-10-04 00:07:16

标签: python-3.x machine-learning deep-learning computer-vision pytorch

torchvision.models获取经过预训练的模型后,我希望将所有ReLU实例都存储到register_backward_hook(f),如下所示:

for pos, module in self.model.features._modules.items():
    for sub_module in module:
        if isinstance(module, ReLU):
            module.register_backward_hook(f)

对我来说,问题是如何在模型中找到所有ReLU。对于densenet161ReLU不仅存在于model.features._modules中,而且还存在于自定义的密集层中,例如。 model.features._modules['denseblock1'][0]。对于resnet151ReLU存在于model._modules及其自定义层中,例如model._modules['layer1']

有什么方法可以找到模型中的所有ReLU吗?

1 个答案:

答案 0 :(得分:3)

一种遍历模型所有组件的更优雅的方法是使用modules() 方法:

from torch import nn

for module in self.model.modules():
  if isinstance(module, nn.ReLU):
    module.register_backward_hook(f)

如果您不希望获取所有子模块,而仅获取直接子模块,则可以考虑使用children()方法而不是modules()。您还可以使用named_modules()方法获取子模块的名称。