具有数据增强功能的Keras ImageDataGenerator sample_weight

时间:2019-03-08 10:58:05

标签: keras

我对使用ImageDataGenerator在Keras中进行数据增强时使用sample_weight参数有疑问。假设我有一系列带有一个对象类别的简单图像。因此,对于每个图像,我将有一个对应的蒙版,其背景像素= 0,对象标记位置为1。

但是,此数据集不平衡,因为这些图像中有很大一部分是空的,这意味着蒙版仅包含0。 如果我了解得很好,那么ImageDataGenerator的flow方法的“ sample_weight”参数将把重点放在我发现更有趣的数据集样本上,即对象所在的位置。

我的问题是:这个sample_weight参数对我的模型训练有什么具体影响?它会影响数据扩充吗?如果我使用'validation_split'参数,是否会影响验证集的生成方式?

这是我的问题所涉及的代码的一部分:

data_gen_args = dict(rotation_range=90,
                     width_shift_range=0.4,
                     height_shift_range=0.4,
                     zoom_range=0.4,
                     horizontal_flip=True,
                     fill_mode='reflect',
                     rescale=1. / 255,
                     validation_split=0.2,
                     data_format='channels_last'
)    

image_datagen = ImageDataGenerator(**data_gen_args)


imf = image_datagen.flow(
    x=stacked_images_channel,
    y=stacked_masks_channel,
    batch_size=batch_size,
    shuffle=False,
    seed=seed,subset='training',
    sample_weight = sample_weight,
    save_to_dir = 'traindir',
    save_prefix = 'train_'
)

valf = image_datagen.flow(
    x=stacked_images_channel,
    y=stacked_masks_channel,
    batch_size=batch_size,
    shuffle=False,
    seed=seed,subset='validation',
    sample_weight = sample_weight,
    save_to_dir = 'valdir',
    save_prefix = 'val_'
)

STEP_SIZE_TRAIN=imf.n//imf.batch_size
STEP_SIZE_VALID=valf.n//valf.batch_size

model = unet.UNet2(numberOfClasses, imshape, '', learningRate, depth=4)

history = model.fit_generator(generator=imf,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    epochs=epochs,
                    validation_data=valf,
                    validation_steps=STEP_SIZE_VALID,
                    verbose=2
)

预先感谢您的关注。

1 个答案:

答案 0 :(得分:0)

对于在1.1.0进行预处理的Keras 2.2.5,sample_weight与样本一起传递并在处理期间应用。调用.fit_generator时,模型是each batch using sample weights分批训练的:

model.train_on_batch(x, y,
                     sample_weight=sample_weight,
                     class_weight=class_weight)

.train_on_batch的源代码中,documentation states:“ sample_weight:与x长度相同的可选数组,其中包含权重,该权重应用于每个样本的模型损失。(...) ”。重量的实际应用是在计算每批次的损失时发生的。编译模型时,Keras会从所需的损失函数中生成“加权损失”函数。加权计算在code中表示为:

def weighted(y_true, y_pred, weights, mask=None):
        """Wrapper function.
        # Arguments
            y_true: `y_true` argument of `fn`.
            y_pred: `y_pred` argument of `fn`.
            weights: Weights tensor.
            mask: Mask tensor.
        # Returns
            Scalar tensor.
        """
        # score_array has ndim >= 2
        score_array = fn(y_true, y_pred)
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in Theano
            mask = K.cast(mask, K.floatx())
            # mask should have the same shape as score_array
            score_array *= mask
            #  the loss per batch should be proportional
            #  to the number of unmasked samples.
            score_array /= K.mean(mask) + K.epsilon()

        # apply sample weighting
        if weights is not None:
            # reduce score_array to same ndim as weight array
            ndim = K.ndim(score_array)
            weight_ndim = K.ndim(weights)
            score_array = K.mean(score_array,
                                 axis=list(range(weight_ndim, ndim)))
            score_array *= weights
            score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
        return K.mean(score_array)

此包装器显示,它首先计算所需的损失(调用fn(y_true, y_pred)),然后在传递权重(使用sample_weightclass_weight)时应用称重。

考虑到这种情况:

这个sample_weight参数对我的模型训练有什么具体影响。

重量基本上乘以损失(并归一化)。因此,“较重”的权重(超过1个)样本会导致更多的损失,因此梯度更大。 “轻”重量降低了样品的重要性,并导致较小的梯度。

它会影响数据扩充吗?

这取决于您的意思。根据经验,我可以在提供Keras数据生成器之前在 之前执行增强操作(据我所知,在预处理1.1.0中仍然存在问题)。

  • 当将已经扩充的数据馈送到生成器时,.flow调用将需要一个样本权重列表,只要输入数据即可。因此,加权对增强的影响取决于权重的选择方式。扩展N次的数据点可以为每个扩展分配相同的权重,或者根据意图分配1 / N的权重。
  • Keras中的默认行为似乎为Keras执行的每个增强(转换)分配了相同的权重。 code看起来很清晰,尽管我从不依赖它。

如果我使用'validation_split'参数,是否会影响验证集的生成方式?

sample_weight参数似乎不会干扰validation_split。我没有专门研究代码,但是拆分基本上会获取输入数据,并保留拆分以进行验证-无论数据是什么。添加sample_weight时,每个数据点的变化是:没有权重,数据为(x, y);权重越高,数据就变成(x, y, weight)