Sequence generator error: StopIteration: list index out of range

Date: 2018-06-16 06:42:25

Tags: python keras

I am using a Sequence generator in Keras to fetch data from disk in parallel, but I am getting a very strange error.

Here is my Sequence generator code:

import cv2
import numpy as np
from keras.utils import Sequence


class detracSequence(Sequence):

    def __init__(self, x_set, y_set, bbox_set, batch_size):
        self.x, self.y, self.bbox = x_set, y_set, bbox_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        print 'index range', idx*self.batch_size, 'till', (idx+1)*self.batch_size
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_bbox = self.bbox[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Load each image, crop it to its bounding box, and resize to 128x128
        imgs = np.ndarray((self.batch_size,128,128,3))
        for file_index in range(self.batch_size):
            temp = cv2.imread(batch_x[file_index])
            if temp.shape[0] == 0:
                print '1', batch_x[file_index]

            x1, x2, x3, x4 = batch_bbox[file_index,0], batch_bbox[file_index,1], batch_bbox[file_index,2], batch_bbox[file_index,3]
            temp_ = temp[int(x2):int(x4), int(x1):int(x3)]
            imgs[file_index] = cv2.resize(temp_, (128,128))

        return imgs, np.array(batch_y)

Here is the code that calls this generator:

Xtrain_gen = detracSequence(X_train,y_train,training_coordinates, batch_size=32)
history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen,callbacks=callbacks_list,use_multiprocessing=True)
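
The validation generator Xvalidation_gen passed to fit_generator is not shown in the question; presumably it is built the same way as the training one. A minimal sketch, assuming hypothetical validation arrays X_val, y_val and validation_coordinates:

# Hypothetical names: X_val, y_val and validation_coordinates are not from the question
Xvalidation_gen = detracSequence(X_val, y_val, validation_coordinates, batch_size=32)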

Now, the value of idx is generated internally by Keras, and my expectation was that it would take care of the index limits. But in __getitem__(self, idx) I receive a value of idx that produces an index-out-of-range error, which seems strange (see the short calculation after the traceback). Here is the error log:

Traceback (most recent call last):
  File "finetuneInceptionV3.py", line 112, in <module>
    history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen,callbacks=callbacks_list,use_multiprocessing=True)
  File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/engine/training.py", line 2192, in fit_generator
    generator_output = next(output_generator)
  File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/utils/data_utils.py", line 584, in get
    six.raise_from(StopIteration(e), e)
  File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/six.py", line 737, in raise_from
    raise value
StopIteration: list index out of range
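
To see how the last batch can run past the data even though Keras only calls __getitem__ with idx < len(sequence), here is a quick illustration with made-up sizes (100 samples, batch size 32; these numbers are hypothetical, not taken from the question):

import numpy as np

n_samples, batch_size = 100, 32                           # hypothetical sizes
n_batches = int(np.ceil(n_samples / float(batch_size)))   # __len__ returns 4
last_idx = n_batches - 1                                   # Keras may call __getitem__(3)
start, stop = last_idx * batch_size, (last_idx + 1) * batch_size
print start, stop  # 96 128 -> this slice holds only 4 items, but range(batch_size) asks for 32

Inside the loop, batch_x[file_index] then fails for file_index >= 4, and the queueing code in data_utils re-raises that IndexError as the StopIteration shown in the traceback.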

Now, I don't know how to solve this without going into the source code, which I would rather not do. Can anyone tell me if I am missing something here?

1 Answer:

Answer 0 (score: 1)

You are assuming that your data divides evenly by batch_size, which is not necessarily the case: when you slice, the last batch can be smaller than batch_size. Instead of iterating over a fixed range, use the size of the slice:

imgs = np.ndarray((len(batch_x),128,128,3)) # here
for file_index in range(len(batch_x)): # and here
  temp = cv2.imread(batch_x[file_index])

That way your loop index never goes past the end of the batch slice.
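
Put together, a corrected __getitem__ would look roughly like this (a sketch of the answer's fix applied to the question's method, with the debug prints dropped):

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_bbox = self.bbox[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Allocate for the actual slice length, not the nominal batch_size
        imgs = np.ndarray((len(batch_x), 128, 128, 3))
        for file_index in range(len(batch_x)):
            temp = cv2.imread(batch_x[file_index])
            x1, x2, x3, x4 = batch_bbox[file_index,0], batch_bbox[file_index,1], batch_bbox[file_index,2], batch_bbox[file_index,3]
            temp_ = temp[int(x2):int(x4), int(x1):int(x3)]
            imgs[file_index] = cv2.resize(temp_, (128,128))

        return imgs, np.array(batch_y)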