I have trained my encoder-decoder model and saved it using:

model_state = {
    'encoder': encoder,
    'encoder_optimizer': encoder_optimizer,
    'decoder': decoder,
    'decoder_optimizer': decoder_optimizer
}
torch.save(model_state, "best_model.pth.tar")
This works fine when I use the model on its own, but it gives me an error when I try to use the model in another application. So I tried to load the model and save the encoder and decoder as state_dicts instead. This works for my encoder, but when I try:
checkpoint = torch.load(path_to_model, map_location=torch.device("cpu"))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
encoder = checkpoint['encoder']
encoder = encoder.to(device)
torch.save(encoder.state_dict(), 'encoder.dict')
torch.save(decoder.state_dict(), 'decoder.dict')
it fails on torch.save(decoder.state_dict(), 'decoder.dict') with the following error:
File "<stdin>", line 1, in <module>
File "caption.py", line 31, in load_maps
torch.save(decoder.state_dict(), 'decoder.dict')
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 695, in state_dict
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 695, in state_dict
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 696, in state_dict
for hook in self._state_dict_hooks.values():
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 585, in __getattr__
type(self).__name__, name))
AttributeError: 'Softmax' object has no attribute '_state_dict_hooks'
Is there a way to work around this error, or to recreate the state_dict without having to retrain my model? I don't understand why I can save the whole model yet cannot get the state_dict from it, since the state_dict is presumably part of the model.
Here is the output of calling for m in decoder.modules(): print(m):
DecoderWithAttention(
(attention): Attention(
(encoder_att): Linear(in_features=2048, out_features=512, bias=True)
(decoder_att): Linear(in_features=512, out_features=512, bias=True)
(full_att): Linear(in_features=512, out_features=1, bias=True)
(relu): ReLU()
(softmax): Softmax(dim=1)
)
(embedding): Embedding(9490, 512)
(dropout): Dropout(p=0.5, inplace=False)
(decode_step): LSTMCell(2560, 512, bias=1)
(init_h): Linear(in_features=2048, out_features=512, bias=True)
(init_c): Linear(in_features=2048, out_features=512, bias=True)
(f_beta): Linear(in_features=512, out_features=2048, bias=True)
(sigmoid): Sigmoid()
(fc): Linear(in_features=512, out_features=9490, bias=True)
)
Attention(
(encoder_att): Linear(in_features=2048, out_features=512, bias=True)
(decoder_att): Linear(in_features=512, out_features=512, bias=True)
(full_att): Linear(in_features=512, out_features=1, bias=True)
(relu): ReLU()
(softmax): Softmax(dim=1)
)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=1, bias=True)
ReLU()
Softmax(dim=1)
Embedding(9490, 512)
Dropout(p=0.5, inplace=False)
LSTMCell(2560, 512, bias=1)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=512, out_features=2048, bias=True)
Sigmoid()
Linear(in_features=512, out_features=9490, bias=True)
Answer 0 (score: 0)
Try saving the model like this:
torch.save({'state_dict': decoder.state_dict()}, 'decoder.pth.tar')
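A minimal sketch of the round trip this suggests, using a small stand-in module since the original DecoderWithAttention class is not shown here (the stand-in architecture and file name are illustrative assumptions):

```python
import torch
import torch.nn as nn

# Stand-in for the real decoder; substitute your own DecoderWithAttention.
decoder = nn.Sequential(
    nn.Linear(512, 512),
    nn.Softmax(dim=1),
)

# Save only the learned parameters, wrapped in a dict under 'state_dict'.
torch.save({'state_dict': decoder.state_dict()}, 'decoder.pth.tar')

# Later (e.g. in another application): rebuild the same architecture
# from code, then load the saved weights into it.
decoder2 = nn.Sequential(
    nn.Linear(512, 512),
    nn.Softmax(dim=1),
)
checkpoint = torch.load('decoder.pth.tar', map_location='cpu')
decoder2.load_state_dict(checkpoint['state_dict'])
```

Because only tensors are pickled (not module objects), the file stays loadable even if the surrounding code or PyTorch version changes, as long as you can reconstruct the architecture in code.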