从训练有素的火炬模型中获取预测

时间:2020-09-16 13:48:19

标签: python computer-vision pytorch transfer-learning torchvision

我正在使用转移学习来微调inception_v3模型。训练模型并保存最佳版本后,我尝试使用它为我的测试集生成预测。下面是我尝试一张图片的示例。

img_test=Image.open("img.png")

#Perform same transformations to image that the model used
transform_pipeline = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img_test = transform_pipeline(img_test)

# I believe this is adding in the batch size of 1, but in looking around online it looked like I needed it
img = img_test.unsqueeze(0)
img = Variable(img)

    
model_ft(img)

当我执行以上操作时,我会得到

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

这似乎暗示我的模型权重在我的gpu上,而变量在cpu上,我该如何移动一个或另一个以使用它,或者引用位于相反处理器上的一个?

1 个答案:

答案 0 :(得分:1)

正如错误所述,似乎模型的输入(您的img_test)在cpu中。

在通过预先训练的模型发送图像之前,尝试将其移至cuda:

device = torch.device('cuda' if torch.cuda.is_available())
img_test = img_test.to(device)