绘制ROC曲线 - 索引错误太多

时间:2014-08-04 18:35:43

标签: python numpy scikit-learn roc

我正在直接从这里获取ROC代码:http://scikit-learn.org/stable/auto_examples/plot_roc.html

我已经在for循环中硬编码了我的类数为46,正如你所看到的,但即使我将它设置为低至2,我仍然会收到错误。

# Compute ROC curve and ROC area for each class
tpr = dict()
roc_auc = dict()
for i in range(46):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

错误是:

Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_classifier_test.py", line 150, in <module>
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
IndexError: too many indices
你可以在这里看到

y_predarray.shape() giving error tuple not callable

y_test只是一个类似于y_pred的一维数组,除了我的问题的真实类。

我不明白,什么指数太多了?

1 个答案:

答案 0 :(得分:2)

您的其他问题和y_pred中显示的y_test都是1-d,因此表达式y_pred[:, i]y_test[:, i]的索引太多。您只能使用单个索引索引一维数组。

那就是说,你应该打电话给roc_curve(y_test, y_pred)