如何在pytorch模型中初始化权重

时间:2020-06-07 14:17:34

标签: python deep-learning pytorch

我这里有一个很简单的问题。 我刚刚完成了重新配置网络的操作,将nn.Upsample替换为下面代码中所示的upConv顺序容器。我已经通过运行summary(UNetPP, (3, 128, 128))验证了所有内容,并且运行没有问题。

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class blockUNetPP(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)


    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

class upConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upc = nn.Sequential(
                                 nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                 nn.Conv2d(in_ch, out_ch*2, 3, stride=1, padding=1),
                                 nn.BatchNorm2d(out_ch*2),
                                 nn.ReLU(inplace=True)
                                 )
    def forward(self, x):
        out = self.upc(x)
        return out

我的问题是,当我尝试开始训练模型时,出现以下问题:

Traceback (most recent call last):
  File "runTrain.py", line 90, in <module>
    netG.apply(weights_init)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 289, in apply
    module.apply(fn)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 290, in apply
    fn(self)
  File "D:\Thesis Models\Deep_learning_models\UNet\train\NetC.py", line 8, in weights_init
    m.weight.data.normal_(0.0, 0.02)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'upConv' object has no attribute 'weight'

我查看了solutions,这建议循环遍历容器模块,但是我已经使用weights_init(m)进行了此操作。有人可以解释我当前设置的问题吗?

1 个答案:

答案 0 :(得分:1)

您正在确定如何初始化权重,方法是检查类名是否包含classname.find('Conv') Conv 。您的类的名称为 upConv ,其中包括 Conv ,因此您尝试初始化其属性.weight,但该属性不存在。

重命名您的类或使条件更严格,例如classname.find('Conv2d')。最严格的方法是检查它是否为nn.Conv2d的实例,而不是查看类的名称。

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)