在scikit-learn中使用交叉验证时绘制Precision-Recall曲线

时间:2014-10-27 12:38:26

标签: python scikit-learn

我正在使用交叉验证来评估具有scikit-learn的分类器的性能,我想绘制Precision-Recall曲线。我在scikit-learn的网站上找到an example来绘制PR曲线,但它没有使用交叉验证进行评估。

如何在使用交叉验证时绘制scikit中的Precision-Recall曲线?

我做了以下但我不确定这是否是正确的方法(psudo代码):

for each k-fold:

   precision, recall, _ =  precision_recall_curve(y_test, probs)
   mean_precision += precision
   mean_recall += recall

mean_precision /= num_folds
mean_recall /= num_folds

plt.plot(recall, precision)

您怎么看?

编辑:

它不起作用,因为每次折叠后precisionrecall数组的大小不同。

2 个答案:

答案 0 :(得分:7)

不是在每次折叠后记录精度和召回值,而是在每次折叠后在测试样本上存储预测。接下来,收集所有测试(即袋外)预测并计算精度和召回。

 ## let test_samples[k] = test samples for the kth fold (list of list)
 ## let train_samples[k] = test samples for the kth fold (list of list)

 for k in range(0, k):
      model = train(parameters, train_samples[k])
      predictions_fold[k] = predict(model, test_samples[k])

 # collect predictions
 predictions_combined = [p for preds in predictions_fold for p in preds]

 ## let predictions = rearranged predictions s.t. they are in the original order

 ## use predictions and labels to compute lists of TP, FP, FN
 ## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation

在单次,完整的k-fold交叉验证运行中,预测变量对每个样本进行一次且仅一次预测。给定n个样本,您应该有n个测试预测。

(注意:这些预测与训练预测不同,因为预测器会对每个样本进行预测,而不会事先看到它。)

除非您使用留一交叉验证,否则k-fold交叉验证通常需要对数据进行随机分区。理想情况下,您可以执行重复(以及分层)k折交叉验证。然而,组合来自不同轮次的精确回忆曲线并不是直截了当的,因为不能在精确回忆点之间使用简单的线性插值,这与ROC不同(参见Davis and Goadrich 2006)。

我亲自计算 AUC-PR 使用Davis-Goadrich方法在PR空间中进行插值(随后进行数值积分),并使用重复分层10倍交叉的AUC-PR估计值对比分类器验证

对于一个不错的情节,我展示了一个交叉验证轮次的代表性PR曲线。

当然,还有许多其他评估分类器性能的方法,具体取决于数据集的性质。

例如,如果数据集中(二进制)标签的比例没有偏差(即大约为50-50),则可以使用更简单的ROC分析和交叉验证:

收集每个折叠的预测并构建ROC曲线(如前所述),收集所有TPR-FPR点(即采用所有TPR-FPR元组的并集),然后绘制可能平滑的组合点集。可选地,使用简单的线性插值和用于数值积分的复合梯形方法计算AUC-ROC。

答案 1 :(得分:1)

这是使用交叉验证绘制sklearn分类器的Precision Recall曲线的最佳方法。最好的部分是,它绘制了所有类的PR曲线,因此您也可以获得多条整齐的曲线

from scikitplot.classifiers import plot_precision_recall_curve
import matplotlib.pyplot as plt

clf = LogisticRegression()
plot_precision_recall_curve(clf, X, y)
plt.show()

该功能自动负责交叉验证给定数据集,连接所有折叠预测,并计算每个类别的PR曲线+平均PR曲线。它是一个单行功能,可以为您完成所有工作。

Precision Recall Curves

免责声明:请注意,这会使用我构建的scikit-plot库。