How to handle exceptions in a custom Keras Sequence used with fit_generator?

Posted: 2019-07-07 16:50:59

Tags: python keras generator tf.keras

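I implemented my own Sequence (from tensorflow.python.keras.utils import Sequence). However, some bad files cause an exception in __getitem__. When I train with model.fit_generator, the whole training process stops as soon as one of these exceptions is raised.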

I would like to handle the exception, for example by simply skipping that batch when an exception occurs.
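A minimal sketch of the "skip the batch" idea, assuming it is acceptable to substitute the next loadable batch for a failing one. fit_generator asks the Sequence for every index from 0 to len - 1, and an exception at any index stops training, so instead of dropping the index the hypothetical wrapper `SkipBadBatches` catches the exception and falls through to the following index. The class name and its `max_retries` and `bad_indices` attributes are illustrative, not part of Keras; the `ImportError` fallback to `object` exists only so the sketch can run without TensorFlow installed.

```python
import logging

try:
    # tf.keras in TensorFlow 2.x (the question imports the same class via tensorflow.python.keras.utils)
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Assumption: fallback so this sketch can be exercised without TensorFlow
    Sequence = object

logger = logging.getLogger(__name__)


class SkipBadBatches(Sequence):
    """Hypothetical wrapper: delegates to another Sequence and, when a batch
    raises, logs the error and returns the next index's batch instead."""

    def __init__(self, base, max_retries=10):
        super().__init__()
        self.base = base
        self.max_retries = max_retries
        self.bad_indices = set()  # indices whose batch raised at least once

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        n = len(self.base)
        # Try idx, idx + 1, ... (wrapping around) until one batch loads.
        for offset in range(min(self.max_retries + 1, n)):
            i = (idx + offset) % n
            try:
                return self.base[i]
            except Exception as exc:  # deliberately broad: any read/decode error skips the batch
                self.bad_indices.add(i)
                logger.warning("Skipping batch %d: %r", i, exc)
        raise RuntimeError("no loadable batch found starting at index %d" % idx)

    def on_epoch_end(self):
        # Forward the epoch hook (e.g. for shuffling) if the wrapped Sequence defines one.
        hook = getattr(self.base, "on_epoch_end", None)
        if callable(hook):
            hook()
```

With the DatasetSequence from the question this would be used as `model.fit_generator(SkipBadBatches(DatasetSequence(image_paths, density_paths)))`, where `image_paths` and `density_paths` stand for your own path lists. The trade-off is that the substituted batch is seen twice in that epoch, and with `use_multiprocessing=True` the `bad_indices` set is filled inside the worker processes rather than the main one.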

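import numpy as np
from PIL import Image
from tensorflow.python.keras.utils import Sequence

# load_density() and random_crop() are the asker's own helpers (not shown here).


class DatasetSequence(Sequence):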

    def __init__(self, image_path_list, density_path_list, random_crop_size=None):
        self.image_path_list = image_path_list
        self.density_path_list = density_path_list
        self.random_crop_size = random_crop_size
        self.batch_size = 1

    def __len__(self):
        return len(self.image_path_list)

    def __getitem__(self, idx):
        image_path = self.image_path_list[idx]
        density_path = self.density_path_list[idx]

        density = load_density(density_path)
        image = np.array(Image.open(image_path, "r").convert("RGB"))
        density = np.expand_dims(density, axis=-1)  # add trailing channel dim, e.g. (H, W) -> (H, W, 1)

        if self.random_crop_size is not None:
            image, density = random_crop(image, density, self.random_crop_size)

        # Preprocess VGG16 input: scale to [0, 1], then normalize with the
        # ImageNet per-channel mean/std (RGB order)
        image = image / 255.0
        image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

        image = np.expand_dims(image, axis=0)  # add batch dim
        density = np.expand_dims(density, axis=0)  # add batch dim

        return image, density
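If the bad files can be found ahead of time, another option is to drop them from the path lists before DatasetSequence is built, so that __getitem__ never sees them. A sketch under that assumption: `filter_loadable_pairs` and `load_image_rgb` are hypothetical helpers, and `load_density` is expected to be the question's own loader, passed in as a callable.

```python
def load_image_rgb(path):
    """Open an image and force a full decode so truncated/corrupt files raise here."""
    from PIL import Image  # imported lazily; the question already depends on PIL
    with Image.open(path) as im:
        return im.convert("RGB")  # convert() loads the pixel data, not just the header


def filter_loadable_pairs(image_paths, density_paths, load_image, load_density):
    """Hypothetical helper: keep only (image, density) pairs that both load.

    Returns the filtered image list, the filtered density list, and a list of
    (image_path, density_path, exception) tuples for the pairs that failed.
    """
    good_images, good_densities, failures = [], [], []
    for img_path, den_path in zip(image_paths, density_paths):
        try:
            load_image(img_path)
            load_density(den_path)
        except Exception as exc:  # any read/decode error marks the whole pair as bad
            failures.append((img_path, den_path, exc))
        else:
            good_images.append(img_path)
            good_densities.append(den_path)
    return good_images, good_densities, failures
```

For example, `image_paths, density_paths, failures = filter_loadable_pairs(image_paths, density_paths, load_image_rgb, load_density)` followed by `DatasetSequence(image_paths, density_paths)`. This costs one full pass over the data before training, and it does not catch errors that only appear later in `__getitem__` (for example inside `random_crop`).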

0 answers