在PyTorch中保存训练模型的最佳方法?

时间:2017-03-09 19:06:41

标签: python serialization deep-learning pytorch tensor

我一直在寻找在PyTorch中保存经过训练的模型的替代方法。到目前为止,我找到了两种选择。

  1. torch.save()保存模型,torch.load()加载模型。
  2. model.state_dict()保存经过培训的模型,model.load_state_dict()加载已保存的模型。
  3. 我已经看到了这个discussion,其中建议采用方法2而不是方法1。

    我的问题是,为什么第二种方法更受欢迎?仅仅因为torch.nn模块具有这两个功能而且我们被鼓励使用它们吗?

7 个答案:

答案 0 :(得分:129)

我在他们的github回购中找到this page,我只是在这里粘贴内容。

保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)保存并仅加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

但是在这种情况下,序列化数据绑定到特定类 以及使用的确切目录结构,因此它可以以各种方式中断 用于其他项目,或经过一些严重的重构后。

答案 1 :(得分:80)

这取决于你想做什么。

案例#1:保存模型以自行使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为您通常拥有BatchNormDropout图层,默认情况下在构建时处于列车模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便稍后恢复培训:如果您需要继续训练即将保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,时期,分数等的状态。您可以这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您可以执行以下操作:state = torch.load(filepath),然后恢复每个对象的状态,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复培训,因此在加载时恢复状态后,请勿致电model.eval()

案例#3:其他人无法访问您的代码时使用的模型: 在Tensorflow中,您可以创建一个.pb文件,用于定义模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve时。在Pytorch中执行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不是防弹,因为火炬仍然经历了很多变化,我不推荐它。

答案 2 :(得分:6)

如果您要保存模型并希望以后再继续训练,则:

单个GPU: 保存:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

多个GPU: 保存

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU

答案 3 :(得分:2)

Saving locally

您保存模型的方式取决于您以后希望如何访问它。如果您可以调用 model 类的新实例,那么您需要做的就是使用 model.state_dict() 保存/加载模型的权重:

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

如果你不能因为任何原因(或者更喜欢更简单的语法),那么你可以用 torch.save() 保存整个模型(实际上是对定义模型的文件的引用,以及它的 state_dict):

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

但由于这是对定义模型类的文件位置的引用,因此除非这些文件也移植到相同的目录结构中,否则此代码不可移植。

保存到云端 - TorchHub

如果您希望您的模型具有便携性,您可以使用 torch.hub 轻松地将其导入。如果您将适当定义的 hubconf.py 文件添加到 github 存储库,则可以从 PyTorch 中轻松调用该文件,以使用户能够加载带/不带权重的模型:

hubconf.py (github.com/repo_owner/repo_name)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

加载模型:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)

答案 4 :(得分:0)

pickle Python库实现了二进制协议,用于对Python对象进行序列化和反序列化。

当您import torch(或使用PyTorch)时,它将为您import pickle,并且您无需直接调用pickle.dump()pickle.load(),它们是保存和加载对象的方法。

实际上,torch.save()torch.load()会为您包装pickle.dump()pickle.load()

提到的另一个state_dict答案仅需多加说明。

PyTorch内部有什么state_dict? 实际上有两个state_dict

PyTorch模型为torch.nn.Module,调用了model.parameters()以获取可学习的参数(w和b)。 这些可学习的参数一旦随机设置,将随着我们的学习而随着时间更新。 可学习的参数是第一个state_dict

第二个state_dict是优化程序状态字典。优化器也是模型的一部分。您还记得优化器用于改善我们的可学习参数。但是优化器state_dict是固定的。没什么可学的。

由于state_dict对象是Python字典,因此可以轻松地保存,更新,更改和还原它们,从而为PyTorch模型和优化器增加了很多模块化。

让我们创建一个超级简单的模型来解释这一点:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

此代码将输出以下内容:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

请注意,这是最小模型。您可以尝试添加顺序堆栈

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm层)才在模型的state_dict中具有条目。

不可学习的事物属于优化器对象state_dict,其中包含有关优化器状态以及所用超参数的信息。

故事的其余部分是相同的;在推论阶段(这是我们训练后使用模型的阶段)进行预测;我们会根据所学的参数进行预测。因此,为了进行推断,我们只需要保存参数model.state_dict()

torch.save(model.state_dict(), filepath)

并在以后使用     model.load_state_dict(torch.load(filepath))     model.eval()

注意:不要忘记最后一行model.eval(),这在加载模型之后至关重要。

也不要尝试保存torch.save(model.parameters(), filepath)model.parameters()只是生成器对象。

另一方面,torch.save(model, filepath)保存模型对象本身,但是请记住,模型没有优化程序的state_dict。查看@Jadiel de Armas的其他出色答案,以保存优化程序的状态指令。

答案 5 :(得分:0)

常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。

保存/加载整个模型 保存:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

加载:

模型类必须在某处定义

model = torch.load(PATH)
model.eval()

答案 6 :(得分:0)

这几天什么都写在官方教程里了: https://pytorch.org/tutorials/beginner/saving_loading_models.html

关于如何保存和保存内容,您有多种选择,所有内容都在该教程中进行了说明。

相关问题