如何在sklearn中对不平衡数据集执行交叉验证

时间:2019-03-30 20:59:41

标签: python machine-learning scikit-learn classification

我的数据集高度不平衡,我想执行二进制分类。

阅读某些帖子时,我发现sklearn为不平衡的数据集提供了class_weight="balanced"。因此,我的分类器代码如下。

clf=RandomForestClassifier(random_state = 42, class_weight="balanced")

然后我使用上述分类器进行了十次交叉验证,如下所示。

k_fold = KFold(n_splits=10, shuffle=True, random_state=42)
new_scores = cross_val_score(clf, X, y, cv=k_fold, n_jobs=1)
print(new_scores.mean())

但是,我不确定class_weight="balanced"是否通过10倍交叉验证得到体现。我做错了吗?如果是这样,在sklearn中还有更好的方法吗?

很高兴在需要时提供更多详细信息。

1 个答案:

答案 0 :(得分:4)

您可能要使用分层交叉验证,而不是常规交叉验证。具体来说,您可以使用StratifiedKFold。 而不是代码中的KFold

这样可以确保所有潜在的训练和测试拆分都能捕获班级失衡。