Saving a model with weight decay in PyTorch

Time: 2020-04-09 15:19:20

Tags: python pytorch

This is my model:

import torch
import torch.nn as nn

# basic LeNet5 network
class LeNet5_mode0(nn.Module):

  # constructor 
  def __init__(self):
    super(LeNet5_mode0, self).__init__() # call to super constructor

    # define layers
    # 6 @ 28x28
    self.conv1 = nn.Sequential(
        # Lenet's first conv layer is 3x32x32, squeeze color channels into 1 and pad 2
        nn.Conv2d(in_channels = 1, out_channels = 6, kernel_size = 5, stride = 1, padding = 2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 2, stride = 2)
        )

    # 16 @ 10x10
    self.conv2 = nn.Sequential(
        nn.Conv2d(in_channels = 6, out_channels = 16, kernel_size = 5, stride = 1, padding = 0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 2, stride = 2)
        )

    self.fc1 = nn.Sequential(
        nn.Linear(in_features = 16*5*5, out_features = 120),
        nn.ReLU()
        )

    self.fc2 = nn.Sequential(
        nn.Linear(in_features = 120, out_features = 84),
        nn.ReLU()
        )
    self.classifier = nn.Sequential(
        nn.Linear(in_features = 84, out_features = 10),
        nn.Softmax(dim = 1) # dim=1: softmax over the 10 class scores of each sample
        )

  # define forward function
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(-1, 16*5*5) # reshape the tensor to [-1,16*5*5]
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.classifier(x)
    return x
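Printing the keys shows where state_dict names come from. The sketch below repeats the model above in compact form, with the same attribute names and layer order, and lists its keys. Each key has the form attribute.index-inside-Sequential.tensor-name. Index 1 of conv1 and conv2 is a parameter-free nn.ReLU, so this definition produces no conv1.1.* or conv2.1.* keys at all.

```python
import torch
import torch.nn as nn

class LeNet5_mode0(nn.Module):
    # Compact copy of the question's model: same attributes, same Sequential indices.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
        self.classifier = nn.Sequential(nn.Linear(84, 10), nn.Softmax(dim=1))

    def forward(self, x):
        x = self.conv2(self.conv1(x))
        x = x.view(-1, 16 * 5 * 5)
        return self.classifier(self.fc2(self.fc1(x)))

model = LeNet5_mode0()
# Only modules that own tensors contribute keys; ReLU (index 1) and MaxPool2d (index 2) own none.
keys = list(model.state_dict().keys())
for k in keys:
    print(k)

out = model(torch.randn(2, 1, 28, 28))  # batch of two 1x28x28 images
print(out.shape)  # torch.Size([2, 10])
```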

I trained the model once with:

criterion = nn.CrossEntropyLoss() # aka, LogLoss
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10,15], gamma=0.5)
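The following minimal sketch shows how a criterion, optimizer and scheduler like these are typically driven. It substitutes a stand-in linear model for LeNet5_mode0 and one batch of random data for a real DataLoader. MultiStepLR is stepped once per epoch, after optimizer.step(), so the learning rate halves after epochs 5, 10 and 15.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumptions: stand-in model and random data instead of LeNet5_mode0 and a DataLoader.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))

criterion = nn.CrossEntropyLoss()  # expects raw logits; applies log-softmax internally
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.5)

lrs = []
for epoch in range(12):
    # One "epoch" is a single batch here, to keep the sketch short.
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```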

Then I saved it with:

torch.save(model.state_dict(), savepath)
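torch.save(model.state_dict(), savepath) writes only the model's parameters and buffers. The optimizer and its settings are not part of that file. The sketch below uses a small stand-in model and a temporary directory in place of the real savepath. It also shows one common pattern for saving the optimizer state alongside the model when training needs to be resumable.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0005)

savepath = os.path.join(tempfile.mkdtemp(), "model.pt")  # hypothetical location
# Writes module parameters and buffers only; the optimizer is not involved.
torch.save(model.state_dict(), savepath)

# Optional: one checkpoint dict holding both state_dicts, e.g. to resume training later.
ckpt_path = os.path.join(os.path.dirname(savepath), "checkpoint.pt")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, ckpt_path)

saved = torch.load(savepath)
print(list(saved.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```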

I loaded it with:

model.load_state_dict(torch.load(loadpath))
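Loading works only if the module that receives the state_dict is built exactly like the one that produced it. This minimal sketch builds a stand-in architecture with a hypothetical build() helper and uses an in-memory buffer instead of loadpath. When the two definitions match, load_state_dict with the default strict=True returns empty missing_keys and unexpected_keys lists.

```python
import io
import torch
import torch.nn as nn

def build():
    # Hypothetical helper: saving and loading sides must build identical modules.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

trained = build()
buffer = io.BytesIO()  # in-memory file instead of a real loadpath
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

model = build()
result = model.load_state_dict(torch.load(buffer))  # strict=True by default
print(result.missing_keys, result.unexpected_keys)  # [] []

# Every tensor was copied over unchanged.
same = all(torch.equal(a, b) for a, b in zip(trained.state_dict().values(), model.state_dict().values()))
print(same)  # True
```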

So far there was no problem. But then I changed the optimizer to:

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay = 0.0005)
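Adding weight_decay changes only the optimizer, not the model. The sketch below uses a stand-in model, and the names plain and decayed are illustrative. It builds the same network twice, pairs each copy with one of the two optimizers from the question, and compares them. The model state_dict keys are identical, and weight_decay appears only in the optimizer's param_groups. The optimizer change by itself therefore does not add keys to model.state_dict().

```python
import torch
import torch.nn as nn

def build():
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))  # stand-in model

plain, decayed = build(), build()
opt_plain = torch.optim.Adam(plain.parameters(), lr=0.0005)
opt_decay = torch.optim.Adam(decayed.parameters(), lr=0.0005, weight_decay=0.0005)

# The model's state_dict has the same keys either way...
print(list(plain.state_dict()) == list(decayed.state_dict()))  # True
# ...because weight_decay is an optimizer hyperparameter kept in param_groups.
print(opt_plain.state_dict()["param_groups"][0]["weight_decay"])
print(opt_decay.state_dict()["param_groups"][0]["weight_decay"])
```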

Using the same saving and loading code, I then got the following error:

in loading state_dict for LeNet5_mode0:
    Unexpected key(s) in state_dict: "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "conv1.1.num_batches_tracked", "conv2.1.weight", "conv2.1.bias", "conv2.1.running_mean", "conv2.1.running_var", "conv2.1.num_batches_tracked".
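The unexpected keys above (weight, bias, running_mean, running_var and num_batches_tracked at index 1 of conv1 and conv2) are exactly the tensors of an nn.BatchNorm2d layer. That suggests the file came from a model definition with BatchNorm2d after each Conv2d. The sketch below reproduces the message under that assumption: it saves from such a variant and loads into the definition shown above. The with_bn flag and the conv_block helper are hypothetical.

```python
import io
import torch
import torch.nn as nn

def conv_block(cin, cout, padding, with_bn):
    # Conv2d -> (optional BatchNorm2d) -> ReLU -> MaxPool2d
    layers = [nn.Conv2d(cin, cout, kernel_size=5, stride=1, padding=padding)]
    if with_bn:  # hypothetical variant: BatchNorm2d becomes index 1
        layers.append(nn.BatchNorm2d(cout))
    layers += [nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)]
    return nn.Sequential(*layers)

class LeNet5_mode0(nn.Module):
    # Same attribute names as the question's model; forward() is omitted because
    # only state_dict() and load_state_dict() are exercised here.
    def __init__(self, with_bn=False):
        super().__init__()
        self.conv1 = conv_block(1, 6, padding=2, with_bn=with_bn)
        self.conv2 = conv_block(6, 16, padding=0, with_bn=with_bn)
        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
        self.classifier = nn.Sequential(nn.Linear(84, 10), nn.Softmax(dim=1))

buffer = io.BytesIO()  # stands in for savepath
torch.save(LeNet5_mode0(with_bn=True).state_dict(), buffer)
buffer.seek(0)

message = ""
try:
    # strict=True (the default) raises on any key mismatch
    LeNet5_mode0(with_bn=False).load_state_dict(torch.load(buffer))
except RuntimeError as err:
    message = str(err)
print(message)
```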

How can I fix this? And why would a different optimizer affect how the trained network is saved?
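Comparing the key sets of the file and the model before loading points directly at the layers that differ. The sketch below uses small stand-ins: a checkpoint from a Conv2d, BatchNorm2d, ReLU stack is loaded into a Conv2d, ReLU stack. The cleaner fix is usually to rebuild the architecture that was actually saved. Passing strict=False to load_state_dict only skips the mismatched keys, so the BatchNorm weights and statistics are discarded.

```python
import torch
import torch.nn as nn

# Stand-ins (assumptions): "saved" plays a checkpoint written by a model with
# BatchNorm2d at index 1; "model" is a definition without it.
saved = nn.Sequential(nn.Conv2d(1, 6, 5), nn.BatchNorm2d(6), nn.ReLU()).state_dict()
model = nn.Sequential(nn.Conv2d(1, 6, 5), nn.ReLU())

# 1) Diagnose: compare key sets before loading.
expected = set(model.state_dict())
found = set(saved)
print("only in file: ", sorted(found - expected))
print("only in model:", sorted(expected - found))

# 2) Workaround: strict=False copies the matching tensors and reports the rest
#    instead of raising; the BatchNorm tensors are simply not used.
result = model.load_state_dict(saved, strict=False)
print(result.unexpected_keys)
```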
