我是pytorch和numpy的新手,所以这可能是一个愚蠢的问题。我希望看到一些图像被我的网络分类错误,并带有正确的标签和预测的标签。这是我的代码
.draggable_to
提前谢谢
答案 0 :(得分:0)
至少有两种方法可以完成此操作。
一种是存储在评估(运行测试数据)时未正确分类的图像,并将其绘制出来。显示为here
另一种方法是使用TensorBoard。我认为这非常优雅,您可以找到关于它的全面指南here
答案 1 :(得分:0)
def test(dataset, dataloader):
net.eval()
with torch.no_grad():
for batch in dataloader:
inputs = batch[0]
label=batch[1]
inputs = inputs.to(device, non_blocking=True)
outputs = net(inputs)
predictions = torch.argmax(outputs, dim=1)
for sampleno in range(batch[0].shape[0]):
if(label[sampleno]!=predictions[sampleno]):
print("Actual Lable")
print(label[sampleno])
print("Predicted Label")
print(predictions[sampleno])
showimg(inputs[sampleno].cpu())
return predictions
您可以这样编写showing()函数
def showimg(model):
model=np.reshape(model.numpy(),[28,28]) # For 1D Vector
#If you normalize the image then use Next three-line
#Otherwise skip that
mean=np.array([0.485, 0.456, 0.406] )
std=np.array([0.229, 0.224, 0.225])
model=(model*std+mean)
#print(model)
cv2.imshow("ABC", model)
#waits for user to press any key
#(this is necessary to avoid Python kernel form crashing)
cv2.waitKey(0)
#closing all open windows
cv2.destroyAllWindows()
答案 2 :(得分:0)
我收到此错误,不知道这意味着什么
ValueError Traceback (most recent call last)
in
288
289 # test on validation
--> 290 predictions = test(dataset_valid, dataloader_valid)
291 accuracy_valid = 100. * predictions.eq(dataset_valid.dataset.targets[dataset_valid.indices].to(device)).sum().float() / len(dataset_valid)
292
in test(dataset, dataloader)
236 print("Predicted Label")
237 print(predictions[sampleno])
--> 238 showimages(inputs[sampleno].cpu())
239 return predictions
240
in showimages(model)
240
241 def showimages(model):
--> 242 model=np.transpose(model.numpy(),(1,2,0))
243
244
<__array_function__ internals> in transpose(*args, **kwargs)
~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in transpose(a, axes)
649
650 """
--> 651 return _wrapfunc(a, 'transpose', axes)
652
653
~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
59
60 try:
---> 61 return bound(*args, **kwds)
62 except TypeError:
63 # A TypeError occurs if the object does have such a method in its
ValueError: axes don't match array