我正在使用ROC
绘制多类scikitplot
曲线,并且出现数组索引过多的错误。
我尝试了多种方法,但我一直遇到相同的错误。
这是我的代码:
X_train, X_test, y_train, y_test = train_test_split(x_data, labels, test_size=0.1,random_state=101)
dnn_model = tf.estimator.DNNClassifier(hidden_units=[20,20,20],feature_columns=feat_cols,n_classes=57)
train = dnn_model.train(input_fn=input_func,steps=1000)
eval_input_func=tf.estimator.inputs.pandas_input_fn(x=X_test,y=y_test,batch_size=10,num_epochs=1,shuffle=False)
result = dnn_model.evaluate(eval_input_func)
result
{
'accuracy': 0.96153843,
'average_loss': 0.18984392,
'loss': 1.6453141,
'global_step': 1000
}
pred_input_func=tf.estimator.inputs.pandas_input_fn(x=X_test,batch_size=10,num_epochs=1,shuffle=False)
prediction = dnn_model.predict(pred_input_func)
my_pred = list(prediction)
final_pred = [pred['class_ids'][0] for pred in my_pred]
final_pred
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_roc(y_test, final_pred)
plt.show()
IndexError Traceback (most recent call last)
<ipython-input-166-b0aa20ef20be> in <module>
----> 1 skplt.metrics.plot_roc(y_test, final_pred)
2 plt.show()
~\Anaconda2\envs\stroke\lib\site-packages\scikitplot\metrics.py in plot_roc(y_true, y_probas, title, plot_micro, plot_macro, classes_to_plot, ax, figsize, cmap, title_fontsize, text_fontsize)
412 indices_to_plot = np.in1d(classes, classes_to_plot)
413 for i, to_plot in enumerate(indices_to_plot):
--> 414 fpr_dict[i], tpr_dict[i], _ = roc_curve(y_true, probas[:, i],
415 pos_label=classes[i])
416 if to_plot:
IndexError: too many indices for array