我实现了自己的序列类(继承自 `from tensorflow.python.keras.utils import Sequence`)。但是,某些损坏的文件会导致 `__getitem__` 中出现异常。如果使用 `model.fit_generator`,一旦出现异常,训练过程就会停止。
我想处理这些异常,哪怕只是简单地在发生异常时跳过该批次也可以。
class DatasetSequence(Sequence):
    """Keras ``Sequence`` that yields one (image, density) sample per batch.

    ``__getitem__(idx)`` loads the density map at ``density_path_list[idx]``
    and the RGB image at ``image_path_list[idx]``. It optionally random-crops
    both, standardizes the image per channel, and returns both arrays with a
    leading batch axis of size 1.

    Without handling, an error while loading a sample (e.g. a corrupt image
    file) propagates out of ``__getitem__`` and aborts ``model.fit`` /
    ``model.fit_generator``. Keras expects every index in ``range(len(self))``
    to produce a batch, so a bad sample cannot simply be dropped. With
    ``skip_bad_samples=True`` the failure is logged and the next loadable
    index (wrapping around) is returned in its place. That substitute sample
    may therefore appear twice in one epoch.

    Args:
        image_path_list: Paths of the input images, indexed in lockstep with
            ``density_path_list``.
        density_path_list: Paths handed to ``load_density``.
        random_crop_size: If not ``None``, forwarded to ``random_crop``
            together with the image and density map.
        skip_bad_samples: If ``True`` (default), substitute the next loadable
            sample when loading fails. If ``False``, let the exception
            propagate (fail-fast).
    """

    # Per-channel RGB mean / std applied after scaling pixels to [0, 1]
    # (the commonly used ImageNet statistics). Tuples broadcast against the
    # trailing channel axis of an ndarray.
    _RGB_MEAN = (0.485, 0.456, 0.406)
    _RGB_STD = (0.229, 0.224, 0.225)

    def __init__(self, image_path_list, density_path_list, random_crop_size=None,
                 *, skip_bad_samples=True):
        super().__init__()
        self.image_path_list = image_path_list
        self.density_path_list = density_path_list
        self.random_crop_size = random_crop_size
        self.skip_bad_samples = skip_bad_samples
        # Each batch holds exactly one sample. Not read by this class itself.
        self.batch_size = 1

    def __len__(self):
        """Return the number of batches: one per image path."""
        return len(self.image_path_list)

    def _load_sample(self, idx):
        """Load, crop and normalize sample ``idx``; return batched (image, density)."""
        density = load_density(self.density_path_list[idx])
        # The context manager closes the file deterministically, including
        # when decoding a corrupt file raises.
        with Image.open(self.image_path_list[idx], "r") as img:
            image = np.array(img.convert("RGB"))
        # Append a trailing channel axis. A fixed axis=3 raises AxisError for a
        # 2-D (H, W) map on NumPy >= 1.18; axis=-1 handles 2-D and 3-D alike.
        # NOTE(review): assumes load_density returns a 2-D or 3-D array -- confirm.
        density = np.expand_dims(density, axis=-1)
        if self.random_crop_size is not None:
            image, density = random_crop(image, density, self.random_crop_size)
        # Scale to [0, 1], then standardize each RGB channel.
        image = (image / 255.0 - self._RGB_MEAN) / self._RGB_STD
        # Prepend a batch axis of size 1 to both arrays.
        return np.expand_dims(image, axis=0), np.expand_dims(density, axis=0)

    def __getitem__(self, idx):
        """Return the batch for ``idx`` as ``(image, density)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If ``skip_bad_samples`` is set and no sample at all
                could be loaded.
            Exception: Whatever loading raised, if ``skip_bad_samples`` is
                ``False``.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError("index {} out of range for {} samples".format(idx, n))
        if not self.skip_bad_samples:
            return self._load_sample(idx)

        import logging
        logger = logging.getLogger(__name__)
        # Try idx first, then the following indices (wrapping around), so a
        # bad file is replaced by a neighbouring sample instead of crashing.
        for offset in range(n):
            candidate = (idx + offset) % n
            try:
                return self._load_sample(candidate)
            except Exception:  # opaque loaders: failure type is not known here
                logger.warning(
                    "Skipping sample %d (%s): failed to load",
                    candidate, self.image_path_list[candidate], exc_info=True,
                )
        raise RuntimeError("none of the {} samples could be loaded".format(n))