我正在尝试使用AUPRC作为我的gbm模型拟合的自定义指标,因为我有不平衡的分类器。但是,当我尝试合并自定义指标时,我收到代码中提到的以下错误。不确定我做错了什么。
当我运行内联时,auprcSummary()也可以自行运行。当我尝试将它合并到train()时,它给了我一个错误。
library(dplyr) # for data manipulation
library(caret) # for model-building
library(pROC) # for AUC calculations
library(PRROC) # for Precision-Recall curve calculations
auprcSummary <- function(data, lev = NULL, model = NULL){
index_class2 <- data$Class == "Class2"
index_class1 <- data$Class == "Class1"
the_curve <- pr.curve(data$Class[index_class2],
data$Class[index_class1],
curve = FALSE)
out <- the_curve$auc.integral
names(out) <- "AUPRC"
out
}
ctrl <- trainControl(method = "repeatedcv",
number = 10,
repeats = 5,
summaryFunction = auprcSummary,
classProbs = TRUE)
set.seed(5627)
orig_fit <- train(Class ~ .,
data = toanalyze.train,
method = "gbm",
verbose = FALSE,
metric = "AUPRC",
trControl = ctrl)
这是我得到的错误:
Error in order(scores.class0) : argument 1 is not a vector
是因为pr.curve()只将数字向量作为输入(得分/概率?)
答案 0 :(得分:1)
我认为这种方法产生了一个合适的自定义汇总函数:
library(caret)
library(pROC)
library(PRROC)
library(mlbench) #for the data set
data(Ionosphere)
{p}在pr.curve
函数中,可以为每个类的数据点单独提供分类分数,即,对于来自positive / foreground类的数据点的scores.class0
和{{{ 1}}表示负数/背景类的数据点;或者所有数据点的分类分数都以scores.class1
的形式提供,标签以数值(正类为1,负类为0)提供为scores.class0
(我从帮助中复制了这一点)如果不清楚,我道歉的功能)。
我选择在weights.class0
中为scores.class0
和课程作业中的所有内容提供后期概率。
插入符号表示如果trainControl对象的classProbs参数设置为TRUE,则将出现包含类概率的数据中的其他列。因此,对于weights.class0
数据列,Ionosphere
和good
应该存在:
bad
转换为0/1标签可以做:
levels(Ionosphere$Class)
#output
[1] "bad" "good"
as.numeric(Ionosphere$Class) - 1
将成为good
1
将成为bad
现在我们拥有自定义功能的所有数据
0
不使用仅适用于此数据集的auprcSummary <- function(data, lev = NULL, model = NULL){
prob_good <- data$good #take the probability of good class
the_curve <- pr.curve(scores.class0 = prob_good,
weights.class0 = as.numeric(data$obs)-1, #provide the class labels as 0/1
curve = FALSE)
out <- the_curve$auc.integral
names(out) <- "AUPRC"
out
}
,而是可以提取类名并使用它来获取所需的列:
data$good
重要的是要注意每次更新summaryFunction时都需要更新trainControl对象。
lvls <- levels(data$obs)
prob_good <- data[,lvls[2]]
似乎合理
答案 1 :(得分:1)
caret
有一个名为prSummary
的内置函数,可以为您计算。你不必自己编写。