如何显示每个交叉验证折叠的混淆矩阵?

时间:2019-07-11 20:31:05

标签: python scikit-learn cross-validation confusion-matrix

我正在使用Python处理多类分类问题。我已经使用交叉验证(sklearn.model_selection.cross_validate)训练了分类器,并为交叉验证的每一次折叠计算了我需要的指标(精度,准确性,召回率,F1分数)。但是现在我需要为每个折叠保存混淆矩阵。我该怎么办?

我已经尝试使用sklearn.model_selection.cross_val_predict来获取预测类,并使用它和实际类来计算sklearn.metrics.confusion_matrix,但是我只有一个混淆矩阵。

scoring_multiclass = {
            'accuracy': 'accuracy',
            'precision_macro': 'precision_macro',
            'recall_macro': 'recall_macro',
            'f1_macro': 'f1_macro'}

cv = StratifiedKFold(n_splits=4)
cv_results = cross_validate(
    estimator=current_pipe,
    X=X,
    y=y,
    scoring=scoring_multiclass,
    cv=cv,
    n_jobs=-1)

mean_cv_results = mean_scores(cv_results)

results = {
    'precision': mean_cv_results['test_precision_macro'],
    'recall': mean_cv_results['test_recall_macro'],
    'f1': mean_cv_results['test_f1_macro'],
    'accuracy': mean_cv_results['test_accuracy'],}

如何显示每折的混淆矩阵?

0 个答案:

没有答案