Initializing weights in a PyTorch neural network

Date: 2021-03-28 21:58:09

Tags: python-3.x neural-network pytorch torch

I have created this neural network:

import torch.nn as nn

class _netD(nn.Module):
    def __init__(self, num_classes=1, nc=1, ndf=64):
        super(_netD, self).__init__()
        self.num_classes = num_classes
        # nc is number of channels
        # num_classes is number of classes
        # ndf is the number of output channel at the first layer
        self.main = nn.Sequential(
            # input is (nc) x 28 x 28
            # Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 14 x 14
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 3 x 3
            nn.Conv2d(ndf * 4, ndf * 8, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 2 x 2
            nn.Conv2d(ndf * 8, num_classes, 2, 1, 0, bias=False),
            # out size = batch x num_classes x 1 x 1
        )

        if self.num_classes == 1:
            self.main.add_module('prob', nn.Sigmoid())
            # output = probability
        else:
            pass
            # output = scores

    def forward(self, input):
        output = self.main(input)
        return output.view(input.size(0), self.num_classes).squeeze(1)
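
For reference, a quick dummy forward pass (the batch size of 4 below is arbitrary) confirms the shapes noted in the comments:

import torch

D = _netD()
x = torch.randn(4, 1, 28, 28)   # batch x nc x 28 x 28
out = D(x)
print(out.shape)                # torch.Size([4]) -- one probability per sample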

I want to loop through the different layers and apply a weight initialization depending on the type of layer. I am trying to do the following:

D = _netD()

for name, param in D.named_parameters():
    if type(param) == nn.Conv2d:
        param.weight.normal_(...)

But this does not work. Can you help me?

Thanks

1 answer:

Answer 0: (score: 1)

type(param) will only ever return the actual data type, Parameter, for any kind of weight or data in the model. And since named_parameters() doesn't return useful names either when used on an nn.Sequential-based model, you need to look at the modules to see which layers are specifically instances of the nn.Conv2d class, using isinstance, like so:

for layer in D.modules():
    if isinstance(layer, nn.Conv2d):
        layer.weight.data.normal_(...)
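
For example, with concrete arguments filled in (the 0.0 / 0.02 values below are just the usual DCGAN choices, not something fixed by the question):

import torch.nn as nn

D = _netD()
for layer in D.modules():
    if isinstance(layer, nn.Conv2d):
        # example values -- substitute whatever initialization you want
        layer.weight.data.normal_(0.0, 0.02)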

Alternatively, the way recommended by Soumith Chintala himself is to simply loop over your main module itself:

for layer in D.main:
    if isinstance(layer, nn.Conv2d):
        layer.weight.data.normal_(...)

I actually prefer the first approach, since you don't have to specify the exact nn.Sequential module and it searches over every module in the model, but either one should do the job for you.
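
As a side note, the same isinstance check is often packaged as a function and handed to nn.Module.apply, which visits every submodule recursively. A minimal sketch, again with example values:

import torch.nn as nn

def weights_init(m):
    # called once per submodule by .apply(); the numbers are examples
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.zero_()

D = _netD()
D.apply(weights_init)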