使用 sklearn 绘制精度和召回率

时间:2021-01-31 20:15:30

标签: python scikit-learn

我使用自定义 ML 框架创建了一个分类模型。

我有 3 个班级:1、2、3

输入样本:

# y_true, y_pred, and y_scores are lists

print(y_true[0], y_pred[0], y_scores[0])
print(y_true[1], y_pred[1], y_scores[1])
print(y_true[2], y_pred[2], y_scores[2])

1 1 0.6903580037019461
3 3 0.8805178752523366
1 2 0.32107199420078963

使用 sklearn 我可以使用:metrics.classification_report:

metrics.classification_report(y_true, y_pred)

                         precision    recall  f1-score   support

                      1      0.521     0.950     0.673        400
                      2      0.000     0.000     0.000        290
                      3      0.885     0.742     0.807        310

               accuracy                          0.610       1000
              macro avg      0.468     0.564     0.493       1000
           weighted avg      0.482     0.610     0.519       1000

我想生成准确率与召回率的可视化。

但我收到此错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-108-2ebb913a4e4b> in <module>()
----> 1 precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_scores)

1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    534     if not (y_type == "binary" or
    535             (y_type == "multiclass" and pos_label is not None)):
--> 536         raise ValueError("{0} format is not supported".format(y_type))
    537 
    538     check_consistent_length(y_true, y_score, sample_weight)

ValueError: multiclass format is not supported 

我找到了一些例子:

但是如果我已经有了结果,我不太清楚如何对我的数组进行二值化,寻找如何简单地绘制它的指针。

1 个答案:

答案 0 :(得分:0)

precision_recall_curve 有一个参数 pos_label,这是 TP/TN/FP/FN 中“正”类的标签。因此,您可以提取相关概率,然后生成精度/召回点为:

y_pred = model.predict_proba(X)

index = 2  # or 0 or 1; maybe you want to loop?
label = model.classes_[index]  # see below
p, r, t = precision_recall_curve(y_true, y_pred[:, index], pos_label=label)

这里的主要讨厌之处在于您需要通过索引提取 y_pred 的列,但 pos_label 需要实际的类标签。您可以使用 model.classes_ 连接那些。

可能还值得注意的是,新的绘图便利函数 plot_precision_recall_curve 不适用于此:它将模型作为参数,如果它不是二元分类,则会中断。

相关问题