我知道我可以通过torch.save(model.state_dict(), FILE)
或torch.save(model, FILE)
保存模型。但是他们两个都没有保存模型的体系结构。
那么我们如何像在Tensorflow中创建.pb
文件那样在PyTorch中保存模型的体系结构?我想对模型进行不同的调整。如果无法保存模型的体系结构,是否有比每次都复制整个类定义并创建新类更好的方法?
答案 0 :(得分:4)
仅保存所有参数(state_dict
)和所有模块是不够的,因为有一些操作可控制张量,但仅反映在特定实现的实际 code 中(例如reshape
ing in ResNet)。
此外,网络可能没有固定的和预先确定的计算图:您可以想到具有分支或循环(重复发生)的网络。
因此,您必须保存实际代码。
或者,如果网络中没有分支/循环,则可以保存计算图,例如,参见this post。
您还应该考虑使用onnx
导出模型,并具有可以捕获训练后的权重和计算图的表示形式。
答案 1 :(得分:1)
class newModel( oldModelClass):
def __init__(self):
super(newModel, self).__init__()
使用此设置,newModel具有oldModelClass
的所有层以及前进功能。如果需要进行调整,可以在__init__
函数中定义新图层,然后编写一个新的forward函数来对其进行定义。
答案 2 :(得分:1)
关于实际问题:
那么我们如何像在Tensorflow中创建.pb文件那样在PyTorch中保存模型的体系结构?
答案是:您不能
有没有什么方法可以在不声明类定义的情况下加载经过训练的模型? 我希望模型架构以及参数都可以加载。
不,您必须先加载类定义,这是python酸洗限制。
https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11
不过,此PyTorch帖子中列出了其他选项(可能您已经看到了其中的大多数):
https://pytorch.org/tutorials/beginner/saving_loading_models.html