I'm using a Keras Sequence generator to fetch data from disk in parallel, but I'm getting a very strange error.
Here is my Sequence generator code:
import numpy as np
import cv2
from keras.utils import Sequence

class detracSequence(Sequence):
    def __init__(self, x_set, y_set, bbox_set, batch_size):
        self.x, self.y, self.bbox = x_set, y_set, bbox_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        print 'index range', idx * self.batch_size, 'till', (idx + 1) * self.batch_size
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_bbox = self.bbox[idx * self.batch_size:(idx + 1) * self.batch_size]
        imgs = np.ndarray((self.batch_size, 128, 128, 3))
        for file_index in range(self.batch_size):
            temp = cv2.imread(batch_x[file_index])
            if temp.shape[0] == 0:
                print '1', batch_x[file_index]
            x1, x2, x3, x4 = batch_bbox[file_index, 0], batch_bbox[file_index, 1], batch_bbox[file_index, 2], batch_bbox[file_index, 3]
            temp_ = temp[int(x2):int(x4), int(x1):int(x3)]
            imgs[file_index] = cv2.resize(temp_, (128, 128))
        return imgs, np.array(batch_y)
And here is the code that invokes this generator:
Xtrain_gen = detracSequence(X_train, y_train, training_coordinates, batch_size=32)
history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen, callbacks=callbacks_list, use_multiprocessing=True)
Now, the value of idx is generated internally by Keras, so my expectation was that it would respect the index limits. But inside `__getitem__(self, idx)` I receive a value of idx that produces an index-out-of-range error, which is rather strange. Here is the error log:
Traceback (most recent call last):
File "finetuneInceptionV3.py", line 112, in <module>
history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen,callbacks=callbacks_list,use_multiprocessing=True)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/engine/training.py", line 2192, in fit_generator
generator_output = next(output_generator)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/utils/data_utils.py", line 584, in get
six.raise_from(StopIteration(e), e)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/six.py", line 737, in raise_from
raise value
StopIteration: list index out of range
Now, I don't see how to solve this without digging into the Keras source code, which I'd rather not do. Can anyone tell me if I'm missing something here?
Answer 0 (score: 1)
You are assuming your data is exactly divisible by batch_size, which isn't necessarily the case: when slicing, the last batch may be smaller than batch_size. Instead of a fixed range, use the size of the slice:
imgs = np.ndarray((len(batch_x), 128, 128, 3))  # here
for file_index in range(len(batch_x)):          # and here
    temp = cv2.imread(batch_x[file_index])
That way your loop index never goes past the end of the batch.
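To see why the original code fails, here is a minimal sketch with hypothetical sizes (10 samples, batch_size of 4, no images involved): `__len__` uses `np.ceil`, so Keras asks for 3 batches, but slicing leaves the last batch with only 2 items. Sizing the buffer and the loop from `len(batch_x)` instead of `batch_size` handles that partial batch:

```python
import numpy as np

# Hypothetical sizes: 10 samples with batch_size 4 gives
# ceil(10 / 4) = 3 batches; the last batch holds only 2 items.
x = np.arange(10)
batch_size = 4
n_batches = int(np.ceil(len(x) / float(batch_size)))

batch_lens = []
for idx in range(n_batches):
    # Slicing past the end of the array is safe: it simply
    # returns a shorter slice for the final batch.
    batch_x = x[idx * batch_size:(idx + 1) * batch_size]
    # Size buffers from the slice, not from batch_size, so the
    # partial last batch never indexes past the end.
    imgs = np.ndarray((len(batch_x), 2))
    batch_lens.append(len(batch_x))

print(batch_lens)  # [4, 4, 2]
```

Iterating `range(batch_size)` on that last batch would ask for `batch_x[2]` and `batch_x[3]`, which is exactly the "index out of range" the traceback reports.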