我有一个使用LSTM训练的模型。该模型在 GPU (在Google COLABORATORY上)上进行了训练。 我必须保存模型以进行推断;我将在 CPU 上运行。 训练后,我将模型检查点保存如下:
torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')
为了进行推断,我将模型加载为:
# model definition
vocab_size = len(vocab_to_int)+1
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2
model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)
# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
但是,它引发了以下错误:
model.load_state_dict(checkpoint['model_state_dict'])
File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
Missing key(s) in state_dict: "embedding.weight".
Unexpected key(s) in state_dict: "encoder.weight".
保存检查点时是否遗漏了什么?
答案 0 :(得分:1)
这里有两件事要考虑。
您提到要在GPU上训练模型并将其用于CPU推理,因此您需要在 load 函数中添加参数 map_location 传递 torch.device('cpu')。
state_dict密钥不匹配(在您的输出消息中指示),这可能是由于您丢失的某些密钥或您正在加载的 state_dict 中的密钥多于您的模型而导致的。目前正在使用。为此,您必须在 load_state_dict 函数中添加值为 False 的参数 strict 。这将使方法可以忽略键的不匹配。
侧面说明:尝试将pt或pth扩展名用于检查点文件,因为这是一个约定。