使用pytorch模型进行推理

时间:2018-03-11 14:58:45

标签: pytorch

我正在使用fastai库(fast.ai)来训练图像分类器。 fastai创建的模型实际上是一个pytorch模型。

type(model)
<class 'torch.nn.modules.container.Sequential'>

现在,我想从pytorch中使用这个模型进行推理。到目前为止,这是我的代码:

torch.save(model,"./torch_model_v1")
the_model = torch.load("./torch_model_v1")
the_model.eval() # shows the entire network architecture

基于此处显示的示例:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py,我知道我需要编写自己的数据加载类,它将覆盖Dataset类中的一些函数。但是我不清楚的是我需要在测试时应用的转换?特别是,如何在测试时对图像进行标准化?

另一个问题:我在pytorch中保存和加载模型的方法很好吗?我在这里的教程中读到:http://pytorch.org/docs/master/notes/serialization.html我不建议使用我使用的方法。原因尚不清楚。

1 个答案:

答案 0 :(得分:2)

只是澄清:the_model.eval()不仅打印架构,还将模型设置为评估模式

  

特别是,如何在测试时对图像进行标准化?

这取决于您拥有的型号。例如,对于torchvision模块,您必须规范化输入this way

关于如何保存/加载模型,torch.save / torch.load“将对象保存/加载到磁盘文件。”

因此,如果保存the_model,它将保存整个模型对象,包括其体系结构定义和其他一些内部方面。如果保存the_model.state_dict(),它将仅保存包含模型状态(即参数和缓冲区)的字典。保存模型可能会以各种方式破坏代码,因此首选方法是仅保存和加载模型状态。但是,我不确定fast.ai“模型文件”是否实际上是完整模型或模型的状态。你必须检查这个,这样你才能正确加载它。