我想通过PyTorch中的Transfer Learning在图像分类中显示的每个图片输出下显示准确性

时间:2020-01-14 10:26:30

标签: python image classification pytorch data-visualization

我正在跟踪以下链接:https://stackabuse.com/image-classification-with-transfer-learning-and-pytorch/#settingupapretrainedmodel

但是我是新手。请告诉我如何在图像下显示精度值。

1 个答案:

答案 0 :(得分:0)

该示例中使用的模型返回形状(批大小,类)的对数张量。假设您所说的“准确性值”是具有最大概率的类别的预测概率,那么您需要做的是首先通过获取模型输出的SoftMax来计算您的概率,从而给出每个图像的预测概率在您的批次中。他们的visualize_model函数看起来类似于以下内容,尽管我尚未对其进行测试。

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_handeled = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            probabilities = nn.functional.softmax(outputs, dim=-1) # compute probabilities
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_handeled += 1
                ax = plt.subplot(num_images//2, 2, images_handeled)
                ax.axis('off')
                ax.set_title('predicted: {}, probability: {}'.format(class_names[preds[j]], probabilities[preds[j]])) # add predicted class probability
                imshow(inputs.cpu().data[j])

                if images_handeled == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

或者您是说总体分类的准确性?