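Why does GridSearchCV behave differently with the same code?

Asked: 2019-03-23 20:51:31

Tags: python machine-learning scikit-learn gridsearchcv

I am trying to call GridSearchCV to get the best estimator. If I call it with the following parameters: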

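from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=42)

parameters = {'max_depth': [2, 3, 4, 5, 6, 7, 8, 9, 10],
              'min_samples_leaf': [2, 3, 4, 5, 6, 7, 8, 9, 10],
              'min_samples_split': [2, 3, 4, 5, 6, 7, 8, 9, 10]}

scorer = make_scorer(f1_score)

grid_obj = GridSearchCV(clf, parameters, scoring=scorer)
grid_fit = grid_obj.fit(X_train, y_train)

# With the default refit=True, best_estimator_ is already fitted on X_train,
# so the extra fit below is redundant but harmless.
best_clf = grid_fit.best_estimator_
best_clf.fit(X_train, y_train)

best_train_predictions = best_clf.predict(X_train)
best_test_predictions = best_clf.predict(X_test)

# f1_score expects (y_true, y_pred)
print('The training F1 Score is', f1_score(y_train, best_train_predictions))
print('The testing F1 Score is', f1_score(y_test, best_test_predictions))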

The result should be:

The training F1 Score is 0.784810126582
The testing F1 Score is 0.72

For the same data, the result is different when I change only the candidate values from [2,3,4,5,6,7,8,9,10] to [2,4,6,8,10]:

clf = DecisionTreeClassifier(random_state=42)

parameters = {'max_depth': [2, 4, 6, 8, 10],
              'min_samples_leaf': [2, 4, 6, 8, 10],
              'min_samples_split': [2, 4, 6, 8, 10]}

scorer = make_scorer(f1_score)

grid_obj = GridSearchCV(clf, parameters, scoring=scorer)
grid_fit = grid_obj.fit(X_train, y_train)
best_clf = grid_fit.best_estimator_
best_clf.fit(X_train, y_train)
best_train_predictions = best_clf.predict(X_train)
best_test_predictions = best_clf.predict(X_test)

# f1_score expects (y_true, y_pred)
print('The training F1 Score is', f1_score(y_train, best_train_predictions))
print('The testing F1 Score is', f1_score(y_test, best_test_predictions))

Result:

The training F1 Score is 0.814814814815
The testing F1 Score is 0.8

I am confused about how GridSearchCV works.

1 answer:

Answer 0 (score: 0)

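By changing the values that the grid search analyzes, you are evaluating and comparing different sets of hyperparameters for the model. Remember that what GridSearchCV ultimately does is select the set of hyperparameters with the best mean cross-validated score on the training data.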
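To make "different sets of hyperparameters" concrete, here is a minimal sketch that counts the candidates in each grid with scikit-learn's ParameterGrid. The variable names (full_grid, even_grid and so on) are mine, not from the question.

```python
from sklearn.model_selection import ParameterGrid

full_values = [2, 3, 4, 5, 6, 7, 8, 9, 10]
even_values = [2, 4, 6, 8, 10]
names = ('max_depth', 'min_samples_leaf', 'min_samples_split')

# The two grids from the question, built from the same three parameter names.
full_grid = {name: full_values for name in names}
even_grid = {name: even_values for name in names}

n_full = len(ParameterGrid(full_grid))  # 9 * 9 * 9 candidates
n_even = len(ParameterGrid(even_grid))  # 5 * 5 * 5 candidates
print(n_full, n_even)

# Every candidate of the smaller grid also appears in the larger one.
full_candidates = {tuple(sorted(p.items())) for p in ParameterGrid(full_grid)}
is_subset = all(tuple(sorted(p.items())) in full_candidates
                for p in ParameterGrid(even_grid))
print(is_subset)
```

So the first search compares many more candidates than the second, and the winner of one need not be the winner of the other.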

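So in your code, grid_fit.best_estimator_ may be a different model in each case, which naturally explains why the two runs produce different scores on the training and test sets.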

For example, in the first case you might end up with

clf = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, min_samples_split=9)

and in the second case with

clf = DecisionTreeClassifier(max_depth=2, min_samples_leaf=4, min_samples_split=8)

(To check, you can inspect grid_fit.best_params_ in each case.)
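A rough sketch of that check is below. The question's data is not shown, so make_classification is used as a hypothetical stand-in, and run_search is a helper name I made up. cv=3 is set only to keep the run short; the question uses the default.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in data (assumption): the real X_train / y_train are not shown.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

full_values = [2, 3, 4, 5, 6, 7, 8, 9, 10]
even_values = [2, 4, 6, 8, 10]

def run_search(values):
    """Fit a grid search that uses the same candidate values for all three parameters."""
    grid = {'max_depth': values,
            'min_samples_leaf': values,
            'min_samples_split': values}
    search = GridSearchCV(DecisionTreeClassifier(random_state=42), grid,
                          scoring=make_scorer(f1_score), cv=3)
    return search.fit(X_train, y_train)

full = run_search(full_values)
even = run_search(even_values)

# Compare which hyperparameters each search picked and their CV scores.
for label, search in (('full grid', full), ('even grid', even)):
    print(label, search.best_params_, 'CV F1 =', round(search.best_score_, 3))
```

If the two best_params_ dictionaries differ, the two best_estimator_ objects are different trees, and different train and test scores are expected.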

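However, you really should get a score at least as high in the first case, since the second grid search uses a subset of the first one's parameter values. Strictly speaking, that holds for the mean cross-validated score (grid_fit.best_score_) when both searches use identical folds; it does not have to hold for the training or test F1. As @Attack68 mentioned above, the difference is most likely because you are not controlling the randomness at every step.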
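One way to test the subset argument is to pin the folds explicitly and compare best_score_ between the two searches. This is only a sketch under assumptions: the data is a hypothetical make_classification stand-in, run_search is my own helper name, and the 3-fold shuffled StratifiedKFold is my choice, not something from the question.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in data (assumption): the question's data is not shown.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

# Fixed folds, so both searches score every candidate on identical splits.
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
scorer = make_scorer(f1_score)

def run_search(values):
    grid = {name: values
            for name in ('max_depth', 'min_samples_leaf', 'min_samples_split')}
    search = GridSearchCV(DecisionTreeClassifier(random_state=42), grid,
                          scoring=scorer, cv=cv)
    return search.fit(X_train, y_train)

full = run_search([2, 3, 4, 5, 6, 7, 8, 9, 10])
even = run_search([2, 4, 6, 8, 10])

# Look up the smaller grid's winner inside the larger search's results.
idx = list(full.cv_results_['params']).index(even.best_params_)
same_candidate_score = full.cv_results_['mean_test_score'][idx]

print('full grid best CV F1:', full.best_score_)
print('even grid best CV F1:', even.best_score_)
print('even-grid winner scored inside the full search:', same_candidate_score)

# Test scores carry no such guarantee; either search may come out ahead here.
print('full grid test F1:', f1_score(y_test, full.predict(X_test)))
print('even grid test F1:', f1_score(y_test, even.predict(X_test)))
```

With the folds and random_state fixed, the larger grid's best cross-validated score cannot be lower than the smaller grid's. The test F1 printed at the end can still go either way, because the search never sees X_test.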