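I am trying to use a trained model to predict millions of images with predict_generator, using Keras with the TensorFlow backend in Python 3. The generator and the model predictions work fine, but some images in the directory are corrupt or broken, and they make predict_generator stop and raise an error. After I delete the offending image it works again, until the next corrupt image is fed through the function.
Because there are so many images, running a script that opens every image and deletes the ones that raise an error is not feasible. Is there a way to build a "skip the image if it is corrupt" option into the generator or into the flow_from_directory function?
Any help would be greatly appreciated!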
Answer 1:
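There is no such parameter in `ImageDataGenerator`, nor in its `flow_from_directory` method, as you can see in the Keras documentation for both (here and here). One workaround is to extend the `ImageDataGenerator` class and override the `flow_from_directory` method so that it checks whether an image is corrupt before the generator yields it. You can find its source code here.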
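A minimal sketch of the "check before yielding" idea, written as a plain Python generator rather than an actual `ImageDataGenerator` subclass so that it runs without Keras. `skip_corrupted_batches` and the injected `load_one` loader are hypothetical names; `load_one` is assumed to raise `OSError` or `ValueError` for a corrupt file (Pillow raises `OSError` subclasses for files it cannot decode). Each batch is returned together with the filenames that actually loaded, so scores can still be matched to files.

```python
from typing import Callable, Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def skip_corrupted_batches(
    paths: Iterable[str],
    load_one: Callable[[str], T],
    batch_size: int,
    errors: Tuple[type, ...] = (OSError, ValueError),
) -> Iterator[Tuple[List[str], List[T]]]:
    """Yield (paths, items) batches, dropping files whose loader raises.

    `load_one` is a hypothetical per-file loader (e.g. a Pillow-based image
    loader) expected to raise one of `errors` for a corrupt file.
    """
    batch_paths: List[str] = []
    batch_items: List[T] = []
    for path in paths:
        try:
            item = load_one(path)
        except errors:
            continue  # corrupt/unreadable: skip this file and keep going
        batch_paths.append(path)
        batch_items.append(item)
        if len(batch_items) == batch_size:
            yield batch_paths, batch_items
            batch_paths, batch_items = [], []
    if batch_items:  # final, possibly smaller, batch
        yield batch_paths, batch_items


def fake_loader(path: str) -> int:
    """Stand-in for a real image loader: 'bad' files raise like a corrupt JPEG would."""
    if "bad" in path:
        raise OSError("cannot identify image file")
    return len(path)


files = ["a.jpg", "bad1.jpg", "bb.jpg", "ccc.jpg", "bad2.jpg", "dddd.jpg"]
for names, items in skip_corrupted_batches(files, fake_loader, batch_size=2):
    print(names, items)
# ['a.jpg', 'bb.jpg'] [5, 6]
# ['ccc.jpg', 'dddd.jpg'] [7, 8]
```

With a real image loader, one simple way to use this is to `np.stack` each batch and call `model.predict_on_batch` on it in an ordinary loop, which keeps every score next to its filename and avoids having to know the number of batches in advance.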
Answer 2:
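Since this happens during prediction, if any image or batch is skipped you need to keep track of what was skipped, so that the prediction scores can be mapped back to the correct image filenames.
Based on this idea, my DataGenerator keeps a tracker of valid image indexes. Note in particular the variable `valid_index`, which records the indexes of the images that loaded successfully.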
```python
import numpy as np
import keras
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess

# load_batch_image_to_arrays(paths) is the author's own helper (not shown): it returns a list
# with one image array per path, or None for a path whose image could not be decoded.


class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        super().__init__()
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        # dict shared with the caller: batch number -> df index labels of the valid images
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ', idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]
        # list whose elements are either an image array (valid image) or None (corrupt image)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])
        # filter out corrupt images, keeping the df index label of each valid one
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist())
               if u is not None]
        # boundary case: every image in this batch failed to load
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # In Keras 2.x the OrderedEnqueuer (used when workers > 0) drops None batches, see
            # https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621
            # With workers=0 the None is passed on to the model and prediction fails.
            # predict_generator still runs len(self) steps, so a dropped batch may be replaced
            # by a batch from the next pass: check len(predictions) against valid_index.
            return None
        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)
        x, batch_index = zip(*tmp)
        x = np.stack(x)  # list of arrays -> one (n, H, W, C) array
        self.valid_index[idx] = batch_index
        # apply the input preprocessing that Keras provides for ResNet50
        # (np.float was removed in NumPy 1.24, so use an explicit dtype)
        x = resnet50_preprocess(x.astype(np.float32))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0  # reset the counts after each epoch
```
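The answer does not show `load_batch_image_to_arrays`. A hypothetical version could look like this, assuming Pillow and NumPy are installed and that every image is resized to a fixed `target_size` (224×224 by default here, the usual ResNet50 input). It returns `None` instead of raising for any file Pillow cannot decode. Note that Pillow sizes are `(width, height)`, while the resulting arrays are `(height, width, 3)`.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np
from PIL import Image


def load_batch_image_to_arrays(
    paths: Sequence[str], target_size: Tuple[int, int] = (224, 224)
) -> List[Optional[np.ndarray]]:
    """Return one float32 (H, W, 3) array per path, or None if the file can't be decoded.

    Hypothetical helper: target_size is (width, height), as Pillow expects.
    """
    arrays: List[Optional[np.ndarray]] = []
    for path in paths:
        try:
            with Image.open(path) as img:
                # Image.open is lazy; convert() forces a full decode, so truncated
                # files fail here instead of later inside the model.
                img = img.convert("RGB").resize(target_size)
                arrays.append(np.asarray(img, dtype=np.float32))
        except (OSError, SyntaxError, ValueError):
            # Pillow raises OSError subclasses (e.g. UnidentifiedImageError) for
            # unreadable or truncated files; some format plugins raise SyntaxError.
            arrays.append(None)
    return arrays
```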
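During prediction:

```python
data_gen = DataGenerator(df, batch_size=32, valid_index={})  # batch_size here is illustrative

predictions = model.predict_generator(
    generator=data_gen,
    workers=10,
    use_multiprocessing=False,  # threads share valid_index; separate processes would not
    max_queue_size=20,
    verbose=1
).squeeze()

# rebuild the row order of the predictions: batches in index order, valid rows only
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
assert len(indexes) == len(predictions), 'predictions and valid images are misaligned'
result_df = df.loc[indexes].copy()  # copy() avoids pandas' SettingWithCopyWarning
result_df['score'] = predictions
```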
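The index bookkeeping can be sanity-checked without Keras or pandas. The sketch below is a hypothetical pure-Python equivalent (`map_scores_to_files` is not part of the original code) and assumes `df` has a default `RangeIndex`, so the index labels stored in `valid_index` are also list positions. Sorting the batch numbers matters: with `workers=10`, batches are loaded by several threads, so the order in which keys are inserted into `valid_index` need not match batch order, while the predictions themselves come back in batch order.

```python
from typing import Dict, List, Sequence


def map_scores_to_files(
    filenames: Sequence[str],
    valid_index: Dict[int, Sequence[int]],
    predictions: Sequence[float],
) -> Dict[str, float]:
    """Pair each prediction with the filename of the image it was computed from.

    valid_index maps batch number -> positions of the images that loaded
    successfully in that batch, as DataGenerator fills it in. A batch in
    which every image failed has no entry at all.
    """
    order: List[int] = []
    for batch_idx in sorted(valid_index):  # batch order, not insertion order
        order.extend(valid_index[batch_idx])
    if len(order) != len(predictions):
        raise ValueError(
            f"{len(predictions)} predictions but {len(order)} valid images; "
            "a batch was probably dropped or repeated"
        )
    return {filenames[i]: float(p) for i, p in zip(order, predictions)}


# batch_size=2: batch 1 held c.jpg and d.jpg, and d.jpg was corrupt
filenames = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]
valid_index = {2: (4,), 0: (0, 1), 1: (2,)}  # filled out of order by worker threads
scores = map_scores_to_files(filenames, valid_index, [0.9, 0.1, 0.7, 0.3])
print(scores)
# {'a.jpg': 0.9, 'b.jpg': 0.1, 'c.jpg': 0.7, 'e.jpg': 0.3}
```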