保存和加载Pytorch模型检查点以进行推断不起作用

时间:2019-01-18 21:50:55

标签: python-3.x lstm pytorch

我有一个使用LSTM训练的模型。该模型在 GPU (在Google COLABORATORY上)上进行了训练。 我必须保存模型以进行推断;我将在 CPU 上运行。 训练后,我将模型检查点保存如下:

torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')

为了进行推断,我将模型加载为:

# model definition
vocab_size = len(vocab_to_int)+1 
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2

model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

但是,它引发了以下错误:

model.load_state_dict(checkpoint['model_state_dict'])
  File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
    Missing key(s) in state_dict: "embedding.weight". 
    Unexpected key(s) in state_dict: "encoder.weight".

保存检查点时是否遗漏了什么?

1 个答案:

答案 0 :(得分:1)

这里有两件事要考虑。

  1. 您提到要在GPU上训练模型并将其用于CPU推理,因此您需要在 load 函数中添加参数 map_location 传递 torch.device('cpu')

  2. state_dict密钥不匹配(在您的输出消息中指示),这可能是由于您丢失的某些密钥或您正在加载的 state_dict 中的密钥多于您的模型而导致的。目前正在使用。为此,您必须在 load_state_dict 函数中添加值为 False 的参数 strict 。这将使方法可以忽略键的不匹配。

侧面说明:尝试将pt或pth扩展名用于检查点文件,因为这是一个约定。