我正在使用Gaussian Naive Bayes来训练Pandas数据帧中的模型,但是在使用precision_recall_curve时我遇到了错误。文档说precision_recall_curve将预测的概率作为输入(至少在我读取时),所以我希望下面的工作(xtrain和xtest分别是736和184行的Pandas数据帧; ytrain / ytest是736和184的系列)分别):
nb = GaussianNB()
nb.fit(xtrain, ytrain)
predicted = nb.predict_proba(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted)
我希望上面的工作正常,但是我收到了“IndexError:索引230超出了184的范围”。如果我改为:
predicted = nb.predict(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted)
然后它正确执行。 184是xtest和ytest中的行数,但230不是任何这些结构的维度。有人可以解释差异或我应该如何使用precision_recall_curve来实现这个目的吗?
答案 0 :(得分:1)
如果这是二进制分类,请尝试使用以下内容
predicted = nb.predict_proba(xtest)
precision, recall, threshold = precision_recall_curve(ytest, predicted[:,1])