SciKit-Learn中的自定义记分员 - 允许对特定类进行网格搜索优化

时间:2015-02-03 16:08:51

标签: python machine-learning scikit-learn cross-validation

我想在SciKit-Learn中创建一个自定义分数器,我可以传递给GridSearchCV,它根据特定类的预测准确性来评估模型性能。

假设我的训练数据包含属于三个类别之一的数据点:

  

'dog','cat','mouse'

# Create a classifier:
clf = ensemble.RandomForestClassifier()

# Set up some parameters to explore:
param_dist =    {
                 'n_estimators':[500, 1000, 2000, 4000],
                 "criterion": ["gini", "entropy"],
                 'bootstrap':[True, False]
                }

# Construct grid search
search = GridSearchCV(clf,\
                      param_grid=param_dist,\
                      cv=StratifiedKFold(y, n_folds=10),\
                      scoring=my_scoring_function)


# Perform search
X = training_data
y = ground_truths
search.fit(X, y)

有没有办法构建my_scoring_function,这样只返回'dog'类预测的准确性? make_scorer function似乎是有限的,因为它只处理每个数据点的基本事实和预测类。

非常感谢你的帮助!

1 个答案:

答案 0 :(得分:1)

我错过了sklearn文档中的一个部分。

您可以创建一个需要以下输入的功能; model,x_test,y_test,并输出0到1之间的值(其中1表示最佳),可以用作优化函数。

只需创建函数,应用model.predict(x_test),然后使用精度等指标分析结果。