我的列表索引超出范围

时间:2018-04-03 10:00:34

标签: python python-3.x tensorflow machine-learning deep-learning

我正在尝试使用python创建一个带tensorflow的图像分类器。但是,当我的索引超出范围时,我得到了这个奇怪的错误。该程序假设抓取文件读取前3个字母是否是火车,无论是猫还是狗。

import cv2
import numpy as np
import os
from random import shuffle
from tqdm import tqdm

TRAIN_DIR = 'C:\\Users\\cward\\Desktop\\images\\train'
TEST_DIR = 'C:\\Users\\cward\\Desktop\\images\\test'
IMG_SIZE = 50
LR = 1e-3

MODEL_NAME = 'dogsvscats-{}-{}.model'.format(LR, '2conv-basic')

def label_img(img):
    word_label = img.split('.')[-2]
    if word_label == 'cat': return[1,0]
    elif word_label == 'dog': return[0,1]

def create_train_data():
    training_data = []
    for img in tqdm(os.listdir(TRAIN_DIR)):
        label = label_img(img)
        path = os.path.join(TRAIN_DIR, img)
        img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), 
(IMG_SIZE,IMG_SIZE))
        training_data.append([np.array(img), np.array(label)])
    shuffle(traning_data)
    np.save('train_data.npy', traning_data)
    return training__data

def process_test_data():
    testing_data = []
    for img in tqdm(os.listdir(TRAIN_DIR)):
        path = os.path.join(TRAIN_DIR, img)
        img_num = img.split('.')[0]
        img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), 
(IMG_SIZE,IMG_SIZE))
        testing_data.append([np.array(img), img_num])
    np.save('test_data.npy',testing_data)
    return testing_data

train_data = create_train_data()

这是错误:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-34-40719067ea74> in <module>()
----> 1 train_data = create_train_data()

<ipython-input-32-88b70eb23645> in create_train_data()
      2     training_data = []
      3     for img in tqdm(os.listdir(TRAIN_DIR)):
----> 4         label = label_img(img)
      5         path = os.path.join(TRAIN_DIR, img)
      6         img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (IMG_SIZE,IMG_SIZE))

<ipython-input-31-82bc72a4ed99> in label_img(img)
      1 def label_img(img):
----> 2     word_label = img.split('.')[-2]
      3     if word_label == 'cat': return[1,0]
      4     elif word_label == 'dog': return[0,1]

IndexError: list index out of range

我是python的新手,所以请原谅我可怕的格式化!

1 个答案:

答案 0 :(得分:2)

错误是说img.split('。')的长度小于2

你有TRAIN_DIR内的任何目录吗?那会触发这个错误。我个人的建议是先试试:

try:
    label = label_img(img)
except IndexError:
    print(img)
    continue

这应该打印出会触发错误的所有img值的列表。可能是图像文件缺少扩展名的情况。确定错误并修复任何文件后,您可以执行以下操作:

if len(img.split('.')) < 2:
    continue
lable = label_img(img)

这会导致代码忽略会触发错误的文件。这样,如果您有任何子目录,您的代码仍然可以工作(尽管仍然会忽略子目录中的图像)