在PyTorch中加载模型时出错

时间:2018-02-15 13:03:46

标签: python pytorch

我有以下代码段

...EPoSDb.accdb;Jet OLEDB...

当我运行脚本时,它会抛出如下错误:

from train import predict
import random
import torch


ann=torch.load('ann.pt') #importing trained model


while True:
      k=raw_input("User:")
      intent,top_value,top_index = predict(str(k),ann)
      print(intent)

我的 ann.pt 文件与我的脚本位于同一文件夹中。 请帮我识别修复错误并加载模型。 提前谢谢。

1 个答案:

答案 0 :(得分:0)

当试图保存参数和模型时,pytorch会腌制参数,但只存储模型类的路径。例如,更改树结构或重构可能会破坏加载。 因此,作为documentation points out,建议不要使用保存/加载参数:

  

...序列化数据绑定到特定的类和使用的确切目录结构,因此当在其他项目中使用时,或者在一些严重的重构之后,它可以以各种方式中断。

如需更多帮助,请务必显示保存代码。