Keras如何使用class_weight参数?

时间:2018-08-30 19:16:49

标签: python keras

Keras使用class_weight参数处理不平衡的数据集。

我们可以在doc中找到以下内容:

  

可选的字典将类索引(整数)映射到权重(浮点数),以应用于训练过程中该类样本的模型损失。这可能有助于告诉模型“更多关注”来自代表性不足的类的样本。

这是否意味着class_weight在每个班级的训练误差函数中赋予不同的权重?它对其他地方有影响吗?与代表最多的类的“物理”放置实例相比,它对防止泛化错误真的有效吗?

1 个答案:

答案 0 :(得分:2)

class_weight 参数将与每个训练示例相关的损失加权,与该类别在训练集中的代表性不足成正比。这样可以防止训练期间班级不平衡,并使您的网络对于泛化错误具有鲁棒性。

在物理上删除与最具代表性的类相对应的数据实例时,我会格外小心-如果您的网络很深,因此具有很强的表示能力,那么剔除数据集可能会导致过拟合,因此对验证/测试集。

我建议使用Keras文档中指定的class_weights参数。如果您确实打算从最具代表性的类中删除数据实例,请确保调整网络拓扑以减少模型的表示能力(即添加Dropout和/或L2正则化层)。