如何在PyTorch中保存模型架构?

时间:2020-01-05 00:33:29

标签: pytorch

我知道我可以通过torch.save(model.state_dict(), FILE)torch.save(model, FILE)保存模型。但是他们两个都没有保存模型的体系结构。

那么我们如何像在Tensorflow中创建.pb文件那样在PyTorch中保存模型的体系结构?我想对模型进行不同的调整。如果无法保存模型的体系结构,是否有比每次都复制整个类定义并创建新类更好的方法?

3 个答案:

答案 0 :(得分:4)

仅保存所有参数(state_dict)和所有模块是不够的,因为有一些操作可控制张量,但仅反映在特定实现的实际 code 中(例如reshapeing in ResNet)。

此外,网络可能没有固定的和预先确定的计算图:您可以想到具有分支或循环(重复发生)的网络。

因此,您必须保存实际代码。

或者,如果网络中没有分支/循环,则可以保存计算图,例如,参见this post

您还应该考虑使用onnx导出模型,并具有可以捕获训练后的权重和计算图的表示形式。

答案 1 :(得分:1)

您可以参考[this] [1]文章以了解如何保存分类器。要对模型进行调整,您可以做的是创建一个新模型,该模型是现有模型的子模型。

class newModel( oldModelClass):
    def __init__(self):
        super(newModel, self).__init__()

使用此设置,newModel具有oldModelClass的所有层以及前进功能。如果需要进行调整,可以在__init__函数中定义新图层,然后编写一个新的forward函数来对其进行定义。

答案 2 :(得分:1)

关于实际问题:

那么我们如何像在Tensorflow中创建.pb文件那样在PyTorch中保存模型的体系结构?

答案是:您不能

有没有什么方法可以在不声明类定义的情况下加载经过训练的模型? 我希望模型架构以及参数都可以加载。

不,您必须先加载类定义,这是python酸洗限制。

https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11

不过,此PyTorch帖子中列出了其他选项(可能您已经看到了其中的大多数):

https://pytorch.org/tutorials/beginner/saving_loading_models.html