从torchvision.models
获取经过预训练的模型后,我希望将所有ReLU
实例都存储到register_backward_hook(f)
,如下所示:
for pos, module in self.model.features._modules.items():
for sub_module in module:
if isinstance(module, ReLU):
module.register_backward_hook(f)
对我来说,问题是如何在模型中找到所有ReLU
。对于densenet161
,ReLU
不仅存在于model.features._modules
中,而且还存在于自定义的密集层中,例如。 model.features._modules['denseblock1'][0]
。对于resnet151
,ReLU
存在于model._modules
及其自定义层中,例如model._modules['layer1']
。
有什么方法可以找到模型中的所有ReLU
吗?
答案 0 :(得分:3)
一种遍历模型所有组件的更优雅的方法是使用modules()
方法:
from torch import nn
for module in self.model.modules():
if isinstance(module, nn.ReLU):
module.register_backward_hook(f)
如果您不希望获取所有子模块,而仅获取直接子模块,则可以考虑使用children()
方法而不是modules()
。您还可以使用named_modules()
方法获取子模块的名称。