在Keras中(使用TensorFlow作为后端)我正在构建一个模型,该模型使用具有高度不平衡类(标签)的庞大数据集。为了能够运行培训过程,我创建了一个生成器,将数据块提供给fit_generator
。
根据fit_generator的文档,生成器的输出可以是元组(inputs, targets)
或元组(inputs, targets, sample_weights)
。考虑到这一点,这里有几个问题:
class_weight
考虑整个数据集的所有类的权重,而
sample_weights
考虑每个块的所有类的权重
由发电机创造。那是对的吗?如果没有,有人可以详细说明此事吗?class_weight
提供给fit_generator
,然后将sample_weights
作为每个块的输出?如果是,那为什么呢?如果不是那么哪一个更好?sample_weights
,如果特定块中缺少某些类,如何映射权重?让我举个例子。在我的整个数据集中,我有7个可能的类(标签)。因为这些类是高度不平衡的,所以当我创建较小的数据块作为fit_generator
的输出时,特定块中缺少某些类。我应该如何为这些块创建sample_weights
?答案 0 :(得分:13)
我的理解是class_weight考虑了所有权重 整个数据集的类,而sample_weights则关注 由每个单独的块创建的所有类的权重 发电机。那是对的吗?如果没有,有人可以详细说明 重要?
class_weight
会影响目标函数计算中每个类的相对权重。 sample_weights
,顾名思义,允许进一步控制属于同一类的样本的相对权重。
是否有必要将class_weight同时提供给fit_generator和 那么sample_weights作为每个块的输出?如果是,那为什么呢? 如果不是那么哪一个更好?
这取决于您的申请。在对高度偏斜的数据集进行训练时,类权重很有用;例如,用于检测欺诈性交易的分类器。如果您对批次中的样品没有相同的置信度,则样品重量非常有用。一个常见的例子是对可变不确定性的测量进行回归。
如果我应该为每个块提供sample_weights,我该如何映射 如果特定块中缺少某些类的权重?让 我举个例子。在我的整体数据集中,我有7个可能的类 (标签)。因为这些类是非常不平衡的,当我创建时 较小的数据块作为fit_generator的输出,其中一些 特定块中缺少类。我该怎么创造 这些块的sample_weights?
这不是问题。 sample_weights
是基于每个样本定义的,并且独立于类。出于这个原因,documentation表示(inputs, targets, sample_weights)
应该是相同的长度。
_weighted_masked_objective
中的function engine/training.py
有一个sample_weights示例正在应用。