我已经在GPU(服务器)上训练了一个模型,并保存了state_dict。现在,我想在CPU上本地进行一些测试,因此我下载了state_dict,创建了一个模型,并希望加载该state_dict。经过研究,我找到了torch.load()函数的map_location-attribute。所以最后,我尝试通过以下方式加载state_dict:
device = torch.device('cpu')
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=4)
model.load_state_dict(torch.load('state_dict.pth', map_location=device))
但是我收到以下RuntimeError:
RuntimeError: storage has wrong size: expected 1005829391 got 393216
有人知道如何在我的CPU上加载state_dict吗?