火炬加载可以减轻重量,但不起作用

时间:2019-11-15 11:58:08

标签: computer-vision pytorch

我有这个代码。我在每个时期之后都保存权重,并且代码将其保存。但是,当我加载权重时,损耗值从初始损耗值开始,这意味着加载以某种方式失败。

net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 136)

def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = L1Loss(reduction='sum')

    lr = 0.0000001
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.0005)

    net.to(device)

    state = torch.load('face2.txt')
    net.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])

    for epoch in range(int(0), 200000):
        for batch, data in enumerate(trainloader, 0):
            torch.cuda.empty_cache()

            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs).reshape(-1, 68, 2)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        torch.save(state, 'face2.txt')

if __name__ == '__main__':
    train()   

初始损失为50k以上,几千年后损失为50-60。现在,当我重新运行代码时,我希望它从接近损失的值开始,但是又从大约50k开始。

1 个答案:

答案 0 :(得分:0)

您编写的代码:

net = torchvision.models.resnet18(pretrained=True)

表示您使用相同的网络-重新训练的resnet18重新开始。相反,您应该加载最后一个状态(如果存在),这将解决您的问题。

我会稍微更新您的符号:

state = {
    'epoch': epoch,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

可学习的参数是第一个state_dict(模型状态dict)。

第二个state_dict是优化器状态dict。您还记得优化器用于改善我们的可学习参数。但是优化器state_dict是固定的。没什么可学的。

您的代码在某些时候应该看起来像:

model.load_state_dict(state['model_state_dict'])
optimizer.load_state_dict(state['optimizer_state_dict'])