Pytorch:从CPU上的GPU加载state_dict时存储空间大小错误

时间:2019-12-04 20:59:41

标签: pytorch torch

我已经在GPU(服务器)上训练了一个模型,并保存了state_dict。现在,我想在CPU上本地进行一些测试,因此我下载了state_dict,创建了一个模型,并希望加载该state_dict。经过研究,我找到了torch.load()函数的map_location-attribute。所以最后,我尝试通过以下方式加载state_dict:

device = torch.device('cpu')
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=4)
model.load_state_dict(torch.load('state_dict.pth', map_location=device))

但是我收到以下RuntimeError:

RuntimeError: storage has wrong size: expected 1005829391 got 393216

有人知道如何在我的CPU上加载state_dict吗?

0 个答案:

没有答案