Keras - fit_generator

时间:2017-04-17 20:46:44

标签: tensorflow keras

在Keras中(使用TensorFlow作为后端)我正在构建一个模型,该模型使用具有高度不平衡类(标签)的庞大数据集。为了能够运行培训过程,我创建了一个生成器,将数据块提供给fit_generator

根据fit_generator的文档,生成器的输出可以是元组(inputs, targets)或元组(inputs, targets, sample_weights)。考虑到这一点,这里有几个问题:

  1. 我的理解是这样的 class_weight考虑整个数据集的所有类的权重,而 sample_weights考虑每个块的所有类的权重 由发电机创造。那是对的吗?如果没有,有人可以详细说明此事吗?
  2. 是否有必要同时将class_weight提供给fit_generator,然后将sample_weights作为每个块的输出?如果是,那为什么呢?如果不是那么哪一个更好?
  3. 如果我应该为每个块提供sample_weights,如果特定块中缺少某些类,如何映射权重?让我举个例子。在我的整个数据集中,我有7个可能的类(标签)。因为这些类是高度不平衡的,所以当我创建较小的数据块作为fit_generator的输出时,特定块中缺少某些类。我应该如何为这些块创建sample_weights

1 个答案:

答案 0 :(得分:13)

  

我的理解是class_weight考虑了所有权重   整个数据集的类,而sample_weights则关注   由每个单独的块创建的所有类的权重   发电机。那是对的吗?如果没有,有人可以详细说明   重要?

class_weight会影响目标函数计算中每个类的相对权重。 sample_weights,顾名思义,允许进一步控制属于同一类的样本的相对权重。

  

是否有必要将class_weight同时提供给fit_generator和   那么sample_weights作为每个块的输出?如果是,那为什么呢?   如果不是那么哪一个更好?

这取决于您的申请。在对高度偏斜的数据集进行训练时,类权重很有用;例如,用于检测欺诈性交易的分类器。如果您对批次中的样品没有相同的置信度,则样品重量非常有用。一个常见的例子是对可变不确定性的测量进行回归。

  

如果我应该为每个块提供sample_weights,我该如何映射   如果特定块中缺少某些类的权重?让   我举个例子。在我的整体数据集中,我有7个可能的类   (标签)。因为这些类是非常不平衡的,当我创建时   较小的数据块作为fit_generator的输出,其中一些   特定块中缺少类。我该怎么创造   这些块的sample_weights?

这不是问题。 sample_weights是基于每个样本定义的,并且独立于类。出于这个原因,documentation表示(inputs, targets, sample_weights)应该是相同的长度。

_weighted_masked_objective中的function engine/training.py有一个sample_weights示例正在应用。