Create a new model in PyTorch with custom initial values for the weights

Date: 2019-12-24 10:36:55

Tags: python neural-network pytorch

I am new to PyTorch and I would like to understand how to set the initial weights of my network's first hidden layer. Let me explain a bit better: my network is a very simple MLP with one hidden layer, 784 input values, and 10 output values:

import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        # Dropout module with 0.2 drop probability
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # make sure the input tensor is flattened
        # x = x.view(x.shape[0], -1)

        # hidden layer with ReLU activation, then dropout
        x = self.dropout(F.relu(self.fc1(x)))

        # output layer, so no dropout here
        x = F.log_softmax(self.fc2(x), dim=1)

        return x

So far I have a numpy matrix of shape (128, 784) that contains the values I want for the fc1 weights. How can I initialize the weights of the first layer with the values contained in that matrix?

Searching online through other answers, I found that you have to define a weight-initialization function, for example:

def weights_init(m):
    classname = m.__class__.__name__

    # initialize convolutional layers from N(0, 0.02)
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    # initialize batch-norm weights from N(1, 0.02) and biases to 0
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

but I don't understand how this code works.
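For context, a function like this is normally passed to Module.apply, which calls it recursively on every submodule. A minimal usage sketch (note that this particular function only matches Conv2d and BatchNorm layers, so it would leave the Linear layers of the Classifier above untouched):

model = Classifier()
# apply() walks the module tree and calls weights_init on each submodule
model.apply(weights_init)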

1 Answer:

Answer 0 (score: 0):

You can simply use torch.nn.Parameter() to assign custom weights to a layer of your network.

In your case:

model.fc1.weight = torch.nn.Parameter(custom_weight)

torch.nn.Parameter: a kind of Tensor that is treated as a module parameter; assigning one to a module attribute registers it, so it shows up in model.parameters() and gets updated by the optimizer.

For example:

# Classifier model
model = Classifier()

# your custom weight, here taken at random
custom_weight = torch.rand(model.fc1.weight.shape)
custom_weight.shape
torch.Size([128, 784])

# before assigning the custom weight
print(model.fc1.weight)
Parameter containing:
tensor([[ 1.6920e-02,  4.6515e-03, -1.0214e-02,  ..., -7.6517e-03,
          2.3892e-02, -8.8965e-03],
        ...,
        [-2.3137e-02,  5.8483e-03,  4.4392e-03,  ..., -1.6159e-02,
          7.9369e-03, -7.7326e-03]])

# assign custom weight to first layer
model.fc1.weight = torch.nn.Parameter(custom_weight)

# after assigning the custom weight
model.fc1.weight
Parameter containing:
tensor([[ 0.1724,  0.7513,  0.8454,  ...,  0.8780,  0.5330,  0.5847],
        [ 0.8500,  0.7687,  0.3371,  ...,  0.7464,  0.1503,  0.7720],
        [ 0.8514,  0.6530,  0.6261,  ...,  0.7867,  0.9312,  0.3890],
        ...,
        [ 0.5426,  0.7655,  0.1191,  ...,  0.4343,  0.2500,  0.6207],
        [ 0.2310,  0.4260,  0.4138,  ...,  0.1168,  0.5946,  0.2505],
        [ 0.4220,  0.5500,  0.6282,  ...,  0.5921,  0.7953,  0.9997]])
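
The question starts from a numpy matrix of shape (128, 784) rather than a random tensor, so one extra conversion step is needed. A minimal sketch, assuming a hypothetical array np_weights holding your values:

import numpy as np
import torch

# hypothetical numpy matrix with the desired fc1 values (placeholder data)
np_weights = np.random.rand(128, 784)

model = Classifier()

# convert to a float32 tensor and wrap it as a Parameter before assigning
model.fc1.weight = torch.nn.Parameter(torch.from_numpy(np_weights).float())

Alternatively, you can copy the values into the existing parameter in place, e.g. with model.fc1.weight.copy_(...) inside a torch.no_grad() block; this keeps the original Parameter object, which matters if an optimizer already holds a reference to it.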