如何在sklearn cross_validate函数的评分属性中传递参数?

时间:2019-02-15 22:16:34

标签: python scikit-learn

我想更改精度指标的平均参数,因为会出现此错误

  

“ ValueError:目标是多类的,但average ='binary'。请选择   另一个平均设置。”

我已经阅读了官方网站,但是在使用cross_validate函数方面找不到答案。

clf = RandomForestClassifier()
scoring = ['accuracy', 'precision']

scores = cross_validate(clf, X, Y, scoring=scoring, cv=10, return_train_score=False, n_jobs=-1)

有人知道如何处理吗?

1 个答案:

答案 0 :(得分:1)

使用make_scorer,它允许您为各个得分指标指定参数,然后使用字典将多个指标映射到名称:

from sklearn.metrics import accuracy_score, precision_score, make_scorer
scoring = {'Accuracy': make_scorer(accuracy_score), 
           'Precision': make_scorer(precision_score, average='None')}

scores = cross_validate(clf, X, Y, scoring=scoring, ...)

请参阅this示例