如何在GridSearchCV中执行stratifiedShuffleSplit?

时间:2018-02-02 15:36:05

标签: python python-3.x

我可以在GridSearchCV中运行StraitifiedShuffleSplit,而不必在我的代码中首先将其实例化为“ss”。

ss = StratifiedShuffleSplit(n_splits=3, test_size=0.5, random_state=0)

grid_search = GridSearchCV(clf_us, param_grid = {parameter: num_range},cv=ss)

1 个答案:

答案 0 :(得分:2)

如果您正在构建分类器并且只关心在每个折叠中保持与完整数据集中相同的标签平衡,则可以通过指定GridSearchCV中的折叠数来避免实例化StratifiedShuffleSplit,例如: CV = 5。

根据文档:“对于整数/无输入,如果估计器是分类器,y是二进制或多类,则使用StratifiedKFold。在所有其他情况下,使用KFold。“ http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

但是,如果您希望更好地控制数据拆分,则无法避免实例化StratifiedShuffleSplit。请参阅此页面中的示例,了解test_size参数如何影响拆分:http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.ShuffleSplit.html#sklearn.model_selection.ShuffleSplit