如何计算GridSearchCV将训练多少个模型?

时间:2018-02-09 20:02:27

标签: python scikit-learn grid-search

您如何计算出SKLearn的GridSearchCV将训练多少型号?在我的情况下,我使用以下参数:

learning_rate_range = [0.01, 0.05, 0.1]
max_depth_range = [3, 4, 5, 6, 7]
min_child_weight_range = [6, 7, 8]
subsample_range = [0.6, 0.7, 0.8, 0.9]
colsample_range = [0.7, 0.8, 0.9]

例如,如果您使用的是3次交叉验证,那么总共需要培训多少个模型,以及用于解决此问题的一般方法是什么?

2 个答案:

答案 0 :(得分:3)

根据文档:" GridSearchCV详尽地考虑所有参数组合,而RandomizedSearchCV可以从具有指定分布的参数空间中抽取给定数量的候选者。"。

http://scikit-learn.org/stable/modules/grid_search.html#grid-search

GridSearchCV的一个例子:

http://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_digits.html#sphx-glr-auto-examples-model-selection-plot-grid-search-digits-py

如果您在一个字典中传递上述所有参数,您将获得3x5x3x4x3网格点,并且每个点将交叉验证3次。

答案 1 :(得分:1)

正如@KRKirov所说,参数的总数只是每个参数各自级别的乘积。 SciKit学习提供了一种了解参数总数的简单方法,如下所示:

from sklearn.model_selection import ParameterGrid

parameters = {
learning_rate_range: [0.01, 0.05, 0.1]
max_depth_range: [3, 4, 5, 6, 7]
min_child_weight_range: [6, 7, 8]
subsample_range: [0.6, 0.7, 0.8, 0.9]
colsample_range: [0.7, 0.8, 0.9]
}

grid = ParameterGrid(parameters)
# python 3.6+ for the f format
print (f"The total number of parameters-combinations is: {len(grid)}")

请记住,每个参数组合都会执行5次以进行交叉验证。因此,总执行次数将为5 * len(grid)