我正在使用fastai库(fast.ai)来训练图像分类器。 fastai创建的模型实际上是一个pytorch模型。
type(model)
<class 'torch.nn.modules.container.Sequential'>
现在,我想从pytorch中使用这个模型进行推理。到目前为止,这是我的代码:
torch.save(model,"./torch_model_v1")
the_model = torch.load("./torch_model_v1")
the_model.eval() # shows the entire network architecture
基于此处显示的示例:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py,我知道我需要编写自己的数据加载类,它将覆盖Dataset类中的一些函数。但是我不清楚的是我需要在测试时应用的转换?特别是,如何在测试时对图像进行标准化?
另一个问题:我在pytorch中保存和加载模型的方法很好吗?我在这里的教程中读到:http://pytorch.org/docs/master/notes/serialization.html我不建议使用我使用的方法。原因尚不清楚。
答案 0 :(得分:2)
只是澄清:the_model.eval()
不仅打印架构,还将模型设置为评估模式。
特别是,如何在测试时对图像进行标准化?
这取决于您拥有的型号。例如,对于torchvision
模块,您必须规范化输入this way。
关于如何保存/加载模型,torch.save
/ torch.load
“将对象保存/加载到磁盘文件。”
因此,如果保存the_model
,它将保存整个模型对象,包括其体系结构定义和其他一些内部方面。如果保存the_model.state_dict()
,它将仅保存包含模型状态(即参数和缓冲区)的字典。保存模型可能会以各种方式破坏代码,因此首选方法是仅保存和加载模型状态。但是,我不确定fast.ai“模型文件”是否实际上是完整模型或模型的状态。你必须检查这个,这样你才能正确加载它。