在cross_val_score中,参数cv的使用方式有何不同?

时间:2017-10-01 17:21:19

标签: python python-3.x machine-learning scikit-learn cross-validation

我正在尝试计算如何进行k折交叉验证。我希望有人可以告诉我我的两份印刷声明之间的区别。他们给了我很多不同的数据,我认为他们会是一样的。

##train is my training data, 
##target is my target, my binary class.

dtc = DecisionTreeClassifier()
kf = KFold(n_splits=10)
print(cross_val_score(dtc, train, target, cv=kf, scoring='accuracy'))
print(cross_val_score(dtc, train, target, cv=10, scoring='accuracy'))

1 个答案:

答案 0 :(得分:1)

DecisionTreeClassifier来自ClassifierMixin,正如文档中提到的那样(强调我的):

Computing cross-validated metrics

  

cv参数为整数时,cross_val_score默认使用KFoldStratifiedKFold策略,如果估算器派生自{,则使用后者{1}}

所以,当您通过ClassifierMixin时,您正在使用cv=10策略,而在传递StratifiedKFold时,您使用的是常规cv=kf策略。

在分类中,分层通常试图确保每个测试折叠具有近似相等的类别表示。有关详细信息,请参阅交叉验证Understanding stratified cross-validation