经过CNN模型训练后,我有一个预测犬种的代码,我从下面的函数中获得类别索引。我想显示从该函数获取的idx
类文件夹中的随机图像。
class_name = [item for item in loaders['train'].dataset.classes]
def predict_dog_breed(img,model,class_names):
image = Image.open(img).convert('RGB')
transform = transforms.Compose([
transforms. RandomResizedCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229, 0.224, 0.225])])
image = transform(image)
test_image = image.unsqueeze(0)
net.eval()
output = net(test_image)
idx = torch.argmax(output)
a = random.choice(os.listdir("./dogImages/train/{}/".format (class_name[idx])))
imshow(a)
return class_name[idx]
当我尝试显示随机图像时,出现以下错误:
在os.listdir('./ images')中img_file的TypeError跟踪(最近一次调用为最后一次):2 image = os.path.join('./ images',img_file)----> 3 dog_or_human(image)
在dog_or_human(img)中的5 plt.show()6如果dog_detector(img)==真:----> 7 Forecast_dog = Forecast_dog_breed(img,net,class_name)8 print(“检测到狗!该品种是{}“。format(predict_dog))9 elif face_detector(img)> 0:
predict_dog_breed(img,model,class_name)18 a = random.choice(os.listdir(“ ./ dogImages / train / {} /”。format(class_name [idx])))19 print(a)- -> 20 imshow(a)21 #subdir =''.join([“ / dogImages / train /”,class_name [idx]])22 #print(file)
〜/ Library / Python / 3.7 / lib / python / site-packages / matplotlib / pyplot.py in imshow(X,cmap,norm,aspect,插值,alpha,vmin,vmax,原点,范围,形状,filternorm ,filterrad,imlim,resample,url,data,** kwargs)2697 filternorm = filternorm,filterrad = filterrad,imlim = imlim,2698 resample = resample,url = url,**({“ data”:data})(如果data为不是-> 2699其他{}),** kwargs)2700 sci(__ ret)2701返回__ret
〜/ Library / Python / 3.7 / lib / python / site-packages / matplotlib / init.py in inner(ax,data,* args,** kwargs)1808“ Matplotlib list!)”%(label_namer, func.name),1809 RuntimeWarning,stacklevel = 2)-> 1810 return func(ax,* args,** kwargs)1811 1812 inner.doc = _add_data_doc(inner.doc,
〜/ Library / Python / 3.7 / lib / python / site-packages / matplotlib / axes / _axes.py in imshow(自身,X,cmap,范数,方面,插值,alpha,vmin,vmax,原点,范围,shape,filternorm,filterrad,imlim,resample,url,** kwargs)5492 resample = resample,** kwargs)5493-> 5494 im.set_data(X)5495 im.set_alpha(alpha)5496如果是im.get_clip_path()没有:
〜/ Library / Python / 3.7 / lib / python / site-packages / matplotlib / image.py in set_data(self,A)632 if(self._A.dtype!= np.uint8和633 not np.can_cast (self._A.dtype,float,“ same_kind”)):-> 634引发TypeError(“图像数据无法转换为float”)635636(否则)(self._A.ndim == 2
TypeError:图像数据无法转换为浮点数
任何帮助,将不胜感激!
答案 0 :(得分:1)
因此,我尝试在您的代码here中重现该错误,并成功地做到了这一点。由于代码中的以下行,您会收到错误消息:
a = random.choice(os.listdir("./dogImages/train/{}/".format(class_name[idx])))
imshow(a)
random.choice(os.listdir("./dogImages/train/{}/".format(class_name[idx])))
主要返回图像文件名,它是一个字符串。您不是在读取图像,只是将文件名传递给imshow
函数,这是不正确的。请检查以下数字以进行澄清。
错误代码:
没有错误的代码:
因此,将您的predict_do_breed
函数更改为以下内容:
def predict_dog_breed(img,model,class_name):
image = Image.open(img).convert('RGB')
transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
image = transform(image)
test_image = image.unsqueeze(0)
net.eval()
output = net(test_image)
idx = torch.argmax(output)
a = random.choice(os.listdir("./dogImages/train/{}/".format(class_name[idx])))
print(a)
img = cv2.imread("./dogImages/train/{}/".format(class_name[idx])+a)
imshow(img)
return class_name[idx]
在上面的代码中,cv2.imread
函数用于读取random.choice(os.listdir("./dogImages/train/{}/".format(class_name[idx])))
输出的图像文件名。