混淆矩阵中的Scikit学习变化阈值

时间:2020-04-04 20:06:57

标签: python machine-learning scikit-learn confusion-matrix

对于二进制分类器,我需要在不同的阈值处具有多个混淆矩阵。

我到处都有查询,但是找不到一个简单的实现方法。

任何人都可以提供一种方法来设置scikit-learn的混淆矩阵阈值吗?

我了解scikit-learn的confusion_matrix使用0.5作为阈值。

model = LogisticRegression(random_state=0).fit(X_train, y_train)
y_pred = model.predict(X_test)
confusion_matrix(y_test, y_pred)
Output: array([[24705,     8],
              [  718,     0]])

谢谢!

1 个答案:

答案 0 :(得分:2)

我想通了,

threshold = 0.2
y_pred = (model.predict_proba(X_test)[:, 1] > threshold).astype('float')
confusion_matrix(y_test, y_pred)

希望这对其他所有人都可以通过一种简单的方法来更改阈值有所帮助!