Plotting precision-recall curves for multiple algorithms

Time: 2020-06-27 20:29:00

Tags: python machine-learning scikit-learn

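I want to plot precision-recall curves for the three algorithms I use for text classification. I am a beginner, so could someone show me how to add this to my existing code?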

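from sklearn import model_selection
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC

# The three classifiers to compare
nb_classifier = MultinomialNB()
svm_classifier = LinearSVC()
lr_classifier = LogisticRegression(multi_class="ovr")

# df_train holds the raw text in `data` and the target in `label`
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    df_train.data, df_train.label, test_size=0.2, stratify=df_train['label'])

# Bag-of-words features: unigrams and bigrams of alphabetic tokens with 3+ letters
vect = CountVectorizer(stop_words='english', max_features=10000,
                       token_pattern=r'[a-zA-Z]{3,}', ngram_range=(1, 2))
X_train_dtm = vect.fit_transform(X_train)
X_test_dtm = vect.transform(X_test)

# Fit each model, then predict on the held-out split
nb_classifier.fit(X_train_dtm, y_train)
svm_classifier.fit(X_train_dtm, y_train)
lr_classifier.fit(X_train_dtm, y_train)
nb_predictions = nb_classifier.predict(X_test_dtm)
svm_predictions = svm_classifier.predict(X_test_dtm)
lr_predictions = lr_classifier.predict(X_test_dtm)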

1 answer:

Answer 0 (score: 0)

You can use plot_precision_recall_curve from sklearn.metrics to plot the precision-recall curve for each model, as follows:

nb_classifier = MultinomialNB()
svm_classifier = LinearSVC()
lr_classifier = LogisticRegression(multi_class="ovr")
X_train, X_test, y_train, y_test = model_selection.train_test_split(df_train.data, df_train.label, test_size=0.2 , stratify = df_train['label'])
vect = CountVectorizer(stop_words='english', max_features=10000,
                       token_pattern=r'[a-zA-Z]{3,}' , ngram_range=(1,2))
X_train_dtm = vect.fit_transform(X_train)
X_test_dtm = vect.transform(X_test)
nb_classifier.fit(X_train_dtm, y_train)

svm_classifier.fit(X_train_dtm, y_train)
lr_classifier.fit(X_train_dtm, y_train)
nb_predictions = nb_classifier.predict(X_test_dtm)
svm_predictions = svm_classifier.predict(X_test_dtm)
lr_predictions = lr_classifier.predict(X_test_dtm)

# Plot the precision-recall curve and print the average precision for each classifier
from sklearn.metrics import plot_precision_recall_curve
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt

# Average precision is computed from continuous scores (decision_function or
# predict_proba), not from the hard labels returned by predict(); this assumes
# a binary problem whose positive class is classifier.classes_[1].

# Precision-recall curve for svm_classifier
disp = plot_precision_recall_curve(svm_classifier, X_test_dtm, y_test)
svm_scores = svm_classifier.decision_function(X_test_dtm)
average_precision = average_precision_score(y_test, svm_scores)
print('Average precision score for svm_classifier: {0:0.2f}'.format(
      average_precision))

# Precision-recall curve for nb_classifier
disp = plot_precision_recall_curve(nb_classifier, X_test_dtm, y_test)
nb_scores = nb_classifier.predict_proba(X_test_dtm)[:, 1]
average_precision = average_precision_score(y_test, nb_scores)
print('Average precision score for nb_classifier: {0:0.2f}'.format(
      average_precision))

# Precision-recall curve for lr_classifier
disp = plot_precision_recall_curve(lr_classifier, X_test_dtm, y_test)
lr_scores = lr_classifier.predict_proba(X_test_dtm)[:, 1]
average_precision = average_precision_score(y_test, lr_scores)
print('Average precision score for lr_classifier: {0:0.2f}'.format(
      average_precision))

plt.show()