我正在使用scikit对垃圾邮件/火腿数据执行逻辑回归。 X_train是我的训练数据和y_train标签('垃圾邮件'或'火腿'),我用这种方式训练我的LogisticRegression:
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
如果我想获得10倍交叉验证的准确度,我只想写:
accuracy = cross_val_score(classifier, X_train, y_train, cv=10)
我认为通过这种方式简单地添加一个参数也可以计算精度和召回率:
precision = cross_val_score(classifier, X_train, y_train, cv=10, scoring='precision')
recall = cross_val_score(classifier, X_train, y_train, cv=10, scoring='recall')
但它导致ValueError
:
ValueError: pos_label=1 is not a valid label: array(['ham', 'spam'], dtype='|S4')
它与数据有关(我应该对标签进行二值化吗?)还是更改cross_val_score
函数?
提前谢谢!
答案 0 :(得分:13)
要计算召回率和精确度,必须以这种方式对数据进行二值化:
from sklearn import preprocessing
lb = preprocessing.LabelBinarizer()
lb.fit(y_train)
为了更进一步,我感到惊讶的是,当我想计算准确度时,我没有必要对数据进行二值化:
accuracy = cross_val_score(classifier, X_train, y_train, cv=10)
这只是因为准确度公式并不真正需要关于哪个类被认为是正面还是负面的信息:(TP + TN)/(TP + TN + FN + FP)。我们确实可以看到TP和TN是可交换的,回忆,精度和f1都不是这样。
答案 1 :(得分:3)
我在这里遇到了同样的问题,我用
解决了# precision, recall and F1
from sklearn.preprocessing import LabelBinarizer
lb = LabelBinarizer()
y_train = np.array([number[0] for number in lb.fit_transform(y_train)])
recall = cross_val_score(classifier, X_train, y_train, cv=5, scoring='recall')
print('Recall', np.mean(recall), recall)
precision = cross_val_score(classifier, X_train, y_train, cv=5, scoring='precision')
print('Precision', np.mean(precision), precision)
f1 = cross_val_score(classifier, X_train, y_train, cv=5, scoring='f1')
print('F1', np.mean(f1), f1)
答案 2 :(得分:2)
您在上面显示的语法是正确的。看起来您正在使用的数据存在问题。标签不需要二值化,只要它们不是连续数字。
您可以使用不同的数据集证明相同的语法:
iris = sklearn.dataset.load_iris()
X_train = iris['data']
y_train = iris['target']
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
print cross_val_score(classifier, X_train, y_train, cv=10, scoring='precision')
print cross_val_score(classifier, X_train, y_train, cv=10, scoring='recall')
答案 3 :(得分:1)
您可以使用这样的交叉验证来获得f1分数并回忆:
print('10-fold cross validation:\n')
start_time = time()
scores = cross_validation.cross_val_score(clf, X,y, cv=10, scoring ='f1')
recall_score=cross_validation.cross_val_score(clf, X,y, cv=10, scoring ='recall')
print(label+" f1: %0.2f (+/- %0.2f) [%s]" % (scores.mean(), scores.std(), 'DecisionTreeClassifier'))
print("---Classifier %s use %s seconds ---" %('DecisionTreeClassifier', (time() - start_time)))
了解更多评分参数,请参阅the page
答案 4 :(得分:0)
您应该指定两个标签中的哪一个是正的(可能是火腿):
expect(onPostJobSubmitSpy).toHaveBeenCalled();