我正在尝试使用python创建一个带tensorflow的图像分类器。但是,当我的索引超出范围时,我得到了这个奇怪的错误。该程序假设抓取文件读取前3个字母是否是火车,无论是猫还是狗。
import cv2
import numpy as np
import os
from random import shuffle
from tqdm import tqdm
TRAIN_DIR = 'C:\\Users\\cward\\Desktop\\images\\train'
TEST_DIR = 'C:\\Users\\cward\\Desktop\\images\\test'
IMG_SIZE = 50
LR = 1e-3
MODEL_NAME = 'dogsvscats-{}-{}.model'.format(LR, '2conv-basic')
def label_img(img):
word_label = img.split('.')[-2]
if word_label == 'cat': return[1,0]
elif word_label == 'dog': return[0,1]
def create_train_data():
training_data = []
for img in tqdm(os.listdir(TRAIN_DIR)):
label = label_img(img)
path = os.path.join(TRAIN_DIR, img)
img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE),
(IMG_SIZE,IMG_SIZE))
training_data.append([np.array(img), np.array(label)])
shuffle(traning_data)
np.save('train_data.npy', traning_data)
return training__data
def process_test_data():
testing_data = []
for img in tqdm(os.listdir(TRAIN_DIR)):
path = os.path.join(TRAIN_DIR, img)
img_num = img.split('.')[0]
img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE),
(IMG_SIZE,IMG_SIZE))
testing_data.append([np.array(img), img_num])
np.save('test_data.npy',testing_data)
return testing_data
train_data = create_train_data()
这是错误:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-34-40719067ea74> in <module>()
----> 1 train_data = create_train_data()
<ipython-input-32-88b70eb23645> in create_train_data()
2 training_data = []
3 for img in tqdm(os.listdir(TRAIN_DIR)):
----> 4 label = label_img(img)
5 path = os.path.join(TRAIN_DIR, img)
6 img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (IMG_SIZE,IMG_SIZE))
<ipython-input-31-82bc72a4ed99> in label_img(img)
1 def label_img(img):
----> 2 word_label = img.split('.')[-2]
3 if word_label == 'cat': return[1,0]
4 elif word_label == 'dog': return[0,1]
IndexError: list index out of range
我是python的新手,所以请原谅我可怕的格式化!
答案 0 :(得分:2)
错误是说img.split('。')的长度小于2
你有TRAIN_DIR内的任何目录吗?那会触发这个错误。我个人的建议是先试试:
try:
label = label_img(img)
except IndexError:
print(img)
continue
这应该打印出会触发错误的所有img值的列表。可能是图像文件缺少扩展名的情况。确定错误并修复任何文件后,您可以执行以下操作:
if len(img.split('.')) < 2:
continue
lable = label_img(img)
这会导致代码忽略会触发错误的文件。这样,如果您有任何子目录,您的代码仍然可以工作(尽管仍然会忽略子目录中的图像)