在pytorch中显示分类错误的图像

时间:2020-06-23 14:26:46

标签: numpy matplotlib pytorch tensorboard

我是pytorch和numpy的新手,所以这可能是一个愚蠢的问题。我希望看到一些图像被我的网络分类错误,并带有正确的标签和预测的标签。这是我的代码

.draggable_to

提前谢谢

3 个答案:

答案 0 :(得分:0)

至少有两种方法可以完成此操作。

一种是存储在评估(运行测试数据)时未正确分类的图像,并将其绘制出来。显示为here

另一种方法是使用TensorBoard。我认为这非常优雅,您可以找到关于它的全面指南here

答案 1 :(得分:0)

def test(dataset, dataloader):
    net.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0]
            label=batch[1]
            inputs = inputs.to(device, non_blocking=True)
            outputs = net(inputs)
            predictions = torch.argmax(outputs, dim=1)
            for sampleno in range(batch[0].shape[0]):
                if(label[sampleno]!=predictions[sampleno]):
                    print("Actual Lable")
                    print(label[sampleno])
                    print("Predicted Label")
                    print(predictions[sampleno])
                    showimg(inputs[sampleno].cpu())
            return predictions

您可以这样编写showing()函数

def showimg(model):
    model=np.reshape(model.numpy(),[28,28]) # For 1D Vector
    
    #If you normalize the image then use Next three-line
    #Otherwise skip that
    mean=np.array([0.485, 0.456, 0.406] )
    std=np.array([0.229, 0.224, 0.225])
    model=(model*std+mean)
    


    #print(model)

    cv2.imshow("ABC", model)
    
    #waits for user to press any key
    #(this is necessary to avoid Python kernel form crashing)
    cv2.waitKey(0)

    #closing all open windows
    cv2.destroyAllWindows()

答案 2 :(得分:0)

我收到此错误,不知道这意味着什么

ValueError                                Traceback (most recent call last)
 in 
    288 
    289         # test on validation
--> 290         predictions = test(dataset_valid, dataloader_valid)
    291         accuracy_valid = 100. * predictions.eq(dataset_valid.dataset.targets[dataset_valid.indices].to(device)).sum().float() / len(dataset_valid)
    292 

 in test(dataset, dataloader)
    236                     print("Predicted Label")
    237                     print(predictions[sampleno])
--> 238                     showimages(inputs[sampleno].cpu())
    239             return predictions
    240 

 in showimages(model)
    240 
    241 def showimages(model):
--> 242     model=np.transpose(model.numpy(),(1,2,0))
    243 
    244     

<__array_function__ internals> in transpose(*args, **kwargs)

~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in transpose(a, axes)
    649 
    650     """
--> 651     return _wrapfunc(a, 'transpose', axes)
    652 
    653 

~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     59 
     60     try:
---> 61         return bound(*args, **kwds)
     62     except TypeError:
     63         # A TypeError occurs if the object does have such a method in its

ValueError: axes don't match array