pytorch:无法加载CNN模型并做预测TypeError:'collections.OrderedDict'对象不可调用

时间:2018-01-24 10:03:24

标签: python tensorflow pytorch

我使用MNIST数据集训练了CNN模型,现在想要预测图像的分类,其中包含数字3。

但是当我试图用这个CNN来预测时,pytorch给了我这个错误:

TypeError: 'collections.OrderedDict' object is not callable

这就是我写的:

cnn = torch.load("/usr/prakt/w153/Desktop/score_detector.pkl")
img = scipy.ndimage.imread("/usr/prakt/w153/Desktop/resize_num_three.png")
test_x = Variable(torch.unsqueeze(torch.FloatTensor(img), dim=1), volatile=True).type(torch.FloatTensor).cuda()
test_output, last_layer = cnn(test_x)
pred = torch.max(test_output, 1)[1].cuda().data.squeeze()
print(pred)

这里有一些解释: img是预测图像,大小为28 * 28 score_detector.pkl是受过训练的CNN模型

任何帮助将不胜感激!

2 个答案:

答案 0 :(得分:1)

我非常确定score_detector.pkl实际上是一个state_dict,而不是模型本身。您需要首先实例化模型然后加载state_dict,因此您的第一行应替换为以下内容:

cnn = MyModel()
cnn.load_state_dict("/usr/prakt/w153/Desktop/score_detector.pkl")

然后其余的应该工作。 有关详细信息,请参阅this link

答案 1 :(得分:0)

实际上,您正在加载state_dict而不是模型本身。

保存模型如下:

torch.save(model.state_dict(), 'model_state.pth')

要加载模型状态,您首先需要初始化模型,然后加载状态

model = Model()
model.load_state_dict(torch.load('model_state.pth'))

如果您在GPU上训练了模型,但又想将模型加载到没有CUDA的笔记本电脑上,那么您需要再添加一个参数

model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))