如何加载和使用PyTorch(.pth.tar)模型

时间:2018-08-15 10:45:55

标签: python python-3.x neural-network pytorch torch

我对Torch不太熟悉,主要使用Tensorflow。但是,我需要使用在Torch中经过重新训练的初始模型。由于为我的特定应用重新训练初始模型需要大量的计算资源,因此我想使用已经重新训练的模型。

此模型另存为.pth.tar文件。

我希望能够首先加载此模型。到目前为止,我已经能够确定必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎行得通,因为print(model)会打印出大量数字和其他值,我认为这是权重偏差的值。

在此之后,我需要能够使用它对图像进行分类。我还没弄清楚。我该如何格式化图像?是否应将图像转换为数组?之后,如何将输入数据传递到网络?

1 个答案:

答案 0 :(得分:0)

您基本上需要执行与tensorflow中相同的操作。也就是说,当您存储网络时,将仅存储参数(即网络中的可训练对象),而不是“胶水”,这就是使用训练模型所需的全部逻辑。 因此,如果您有一个.pth.tar文件,则可以加载它,从而覆盖已经定义的模型的参数值。

这意味着保存/加载模型的一般过程如下:

  • 写下您的网络定义(即您的nn.Module对象)
  • 以您想要的方式训练或以其他方式更改网络参数
  • 使用torch.save保存参数
  • 当您要使用该网络时,请使用nn.Module对象的相同定义来首先实例化pytorch网络
  • 然后使用torch.load
  • 覆盖网络参数的值

以下是有关如何执行此操作的讨论:pytorch forums

这是一个超短的mwe:

# to store
torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])