如何更改二进制分类的阈值

时间:2015-08-20 21:41:27

标签: r machine-learning classification

我在R中训练了一个gbm模型。由于我试图预测一个非常罕见的情况,我得到了很多误报。我想将正(“好”)案例的阈值从默认值更改为0.7。到目前为止,这是我的代码。

modFit.glm.ml <- train(as.factor(ml.training$one_lease)~., data=ml.training, method = "glm")
confusionMatrix(ml.testing$one_lease, predict(modFit.glm.ml, ml.testing), positive = "Good")

此代码有效但它使用默认截止值。 有人提到这可能与预测功能有关,但我不知道该怎么做。

2 个答案:

答案 0 :(得分:6)

您还没有提供可重现的示例,因此,现在使用虹膜数据集来预测虹膜是否属于setosa类型:

dat <- iris
dat$positive <- as.factor(ifelse(dat$Species == "setosa", "s", "ns"))
library(caret)
mod <- train(positive~Sepal.Length, data=dat, method="glm")

要使用除0.5之外的预测概率的截止值生成混淆矩阵,您可以使用您想要的任何截止值来阈值predict函数返回的概率:

confusionMatrix(table(predict(mod, type="prob")[,"s"] >= 0.25,
                      dat$positive == "s"))
# Confusion Matrix and Statistics
# 
#        
#         FALSE TRUE
#   FALSE    88    3
#   TRUE     12   47
#                                           
#                Accuracy : 0.9             
#                  95% CI : (0.8404, 0.9429)
#     No Information Rate : 0.6667          
#     P-Value [Acc > NIR] : 2.439e-11       
#                                           
#                   Kappa : 0.7847          
#  Mcnemar's Test P-Value : 0.03887         
#                                           
#             Sensitivity : 0.8800          
#             Specificity : 0.9400          
#          Pos Pred Value : 0.9670          
#          Neg Pred Value : 0.7966          
#              Prevalence : 0.6667          
#          Detection Rate : 0.5867          
#    Detection Prevalence : 0.6067          
#       Balanced Accuracy : 0.9100          
#                                           
#        'Positive' Class : FALSE  

答案 1 :(得分:1)

您尚未指定要使用的软件包,因此这是使用mlr的另一种解决方案:

library(mlr)

dat = iris
training.set = seq(1, nrow(iris), by = 2)
test.set = seq(2, nrow(iris), by = 2)
dat$positive = as.factor(ifelse(dat$Species == "setosa", "s", "ns"))
task = makeClassifTask(data = dat, target = "positive")
lrn = makeLearner("classif.glmnet", predict.type = "prob")

mod = train(lrn, task, subset = training.set)
pred = predict(mod, task, subset = test.set)

print(getConfMatrix(pred))

pred = setThreshold(pred, c(s = 1))
print(getConfMatrix(pred))

mlr允许您使用setThreshold显式设置阈值 - 优点是您可以将结果预测与任何衡量性能的函数一起使用,而无需确保正确设置阈值。

mlr教程在分类器校准方面有a whole section,可帮助您找出此阈值的最佳值。