pytorch nn.module如何保存子模块

时间:2017-11-16 07:57:46

标签: python pytorch

我对pytorch nn.module的工作方式有一些疑问

import torch
import torch.nn as nn



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.sub_module = nn.Linear(10, 5)
        self.value = 3

net = Net()
print(net.__dict__)

输出

{'_modules': OrderedDict([('sub_module', Linear (10 -> 5))]),  'value': 3, ...}

我知道类的每个属性都应存储在 __ dict __ 中,为什么value(一个int值)在其中,但是sub_module(一个nn.Module)不是,而sub_module是存储在 _modules

我读了nn.Module实现的代码,但我没想到。有人有任何想法吗?

谢谢!!

1 个答案:

答案 0 :(得分:1)

我会尽量保持简单。

每次在类Net中创建新项目时,例如:self.sub_module = nn.Linear(10, 5)它会调用其父类的方法__setattr__,在本例中为nn.Module。然后,在__setattr__方法内,参数存储到它们所属的字典中。在这种情况下,由于nn.Linear是一个模块,因此它存储在_modules字典中。

以下是在Modulehttps://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389

中执行此操作的代码段