How to fix "too many indices for array" when plotting a multi-class ROC curve with scikitplot

Time: 2019-05-03 11:20:58

Tags: python matplotlib plot roc scikit-plot

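I am plotting a multi-class ROC curve with scikitplot, and I get a "too many indices for array" error.

I have tried several approaches, but I keep running into the same error.

Here is my code: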

import tensorflow as tf
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(x_data, labels, test_size=0.1, random_state=101)

dnn_model = tf.estimator.DNNClassifier(hidden_units=[20, 20, 20], feature_columns=feat_cols, n_classes=57)

train = dnn_model.train(input_fn=input_func, steps=1000)

eval_input_func = tf.estimator.inputs.pandas_input_fn(x=X_test, y=y_test, batch_size=10, num_epochs=1, shuffle=False)

result = dnn_model.evaluate(eval_input_func)

result

{
 'accuracy': 0.96153843,
 'average_loss': 0.18984392,
 'loss': 1.6453141,
 'global_step': 1000
}

pred_input_func = tf.estimator.inputs.pandas_input_fn(x=X_test, batch_size=10, num_epochs=1, shuffle=False)
prediction = dnn_model.predict(pred_input_func)
my_pred = list(prediction)

final_pred = [pred['class_ids'][0] for pred in my_pred]
final_pred

    import scikitplot as skplt
    import matplotlib.pyplot as plt
    skplt.metrics.plot_roc(y_test, final_pred)
    plt.show()

IndexError                                Traceback (most recent call last)
<ipython-input-166-b0aa20ef20be> in <module>
----> 1 skplt.metrics.plot_roc(y_test, final_pred)
      2 plt.show()


~\Anaconda2\envs\stroke\lib\site-packages\scikitplot\metrics.py in plot_roc(y_true, y_probas, title, plot_micro, plot_macro, classes_to_plot, ax, figsize, cmap, title_fontsize, text_fontsize)
    412     indices_to_plot = np.in1d(classes, classes_to_plot)
    413     for i, to_plot in enumerate(indices_to_plot):
--> 414         fpr_dict[i], tpr_dict[i], _ = roc_curve(y_true, probas[:, i],
    415                                                 pos_label=classes[i])
    416         if to_plot:

IndexError: too many indices for array
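
A likely cause, judging from the traceback: `plot_roc` slices its second argument as `probas[:, i]`, so it expects a 2-D array of per-class probabilities with one row per sample, whereas `final_pred` is a flat list of predicted class ids. The sketch below reproduces the mismatch with NumPy only. The `my_pred` list here is a made-up stand-in for the real predictions, and it assumes each prediction dict carries a `'probabilities'` entry next to `'class_ids'`, as tf.estimator classifiers normally return.

```python
import numpy as np

# Hypothetical stand-in for list(dnn_model.predict(...)): 3 samples, 3 classes.
# The numbers are invented; only the dict layout mirrors the real predictions.
my_pred = [
    {"class_ids": np.array([0]), "probabilities": np.array([0.7, 0.2, 0.1])},
    {"class_ids": np.array([2]), "probabilities": np.array([0.1, 0.3, 0.6])},
    {"class_ids": np.array([1]), "probabilities": np.array([0.2, 0.5, 0.3])},
]

# What the question passes as y_probas: a flat (1-D) array of labels.
final_pred = np.array([pred["class_ids"][0] for pred in my_pred])
print(final_pred.shape)  # (3,)

# plot_roc indexes y_probas as probas[:, i]; a 1-D array has no second axis.
try:
    final_pred[:, 0]
    error_message = None
except IndexError as exc:
    error_message = str(exc)
print(error_message)

# Stacking the per-class probabilities gives the 2-D shape that slice needs.
y_probas = np.array([pred["probabilities"] for pred in my_pred])
print(y_probas.shape)  # (3, 3)
first_class_scores = y_probas[:, 0]
```

With the real model, the equivalent change would be to build `y_probas` from `my_pred` in the same way and call `skplt.metrics.plot_roc(y_test, y_probas)`. This has not been run against the data in the question. The traceback shows `probas[:, i]` being paired with `classes[i]`, so with 57 classes and a 10% test split it may also be worth checking that every class actually appears in `y_test`.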

0 answers:

No answers