如何在pytorch自定义模型的模块类中添加参数?

时间:2019-12-08 09:54:07

标签: deep-learning pytorch

我试图找到答案,但是找不到。

我使用pytorch创建了一个自定义的深度学习模型。例如,

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.nn_layers = nn.ModuleList()
        self.layer = nn.Linear(2,3).double()
        torch.nn.init.xavier_normal_(self.layer.weight)

        self.bias = torch.nn.Parameter(torch.randn(3))

        self.nn_layers.append(self.layer)

    def forward(self, x):
        activation = torch.tanh
        output = activation(self.layer(x)) + self.bias

        return output

如果我打印

model = Net()
print(list(model.parameters()))

它不包含model.bias,因此 Optimizer = Optimizer.Adam(model.parameters())不会更新model.bias。 我该如何处理? 谢谢!

1 个答案:

答案 0 :(得分:1)

您需要register您的参数:

self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(3)))