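I want to compare a multinomial logit model and a random forest in cross-validation using a grouped Brier score. The theoretical basis for this approach is: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3702649/pdf/nihms461154.pdf
My dependent variable has three outcomes, and my dataset contains lifetime data, with lifetimes between 0 and 5.
To make things reproducible, my dataset looks like this: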
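library(data.table)
library(dplyr)   # for ntile()
set.seed(1)      # make the simulation reproducible
N <- 1000
X1 <- rnorm(N, 175, 7)
X2 <- rnorm(N, 30, 8)
length <- sample(0:5, N, replace = TRUE)   # lifetime of each id
Ycont <- 0.5*X1 - 0.3*X2 + 10 + rnorm(N, 0, 6)
Ycateg <- ntile(Ycont, 3)
df <- data.frame(id = 1:N, length, X1, X2, Ycateg)
df$Ycateg <- ifelse(df$Ycateg == 1, "current", ifelse(df$Ycateg == 2, "default", "prepaid"))
# expand to one row per id and period (ids with length 0 drop out)
df <- setDT(df)[, .SD[rep(1L, length)], by = id]
df[, time := seq_len(.N), by = id]   # period counter within each id
df[, length := NULL]                 # drop the helper column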
head(df)
id X1 X2 Ycateg time
1: 1 178.0645 10.84313 1 1
2: 2 169.4208 34.39831 1 1
3: 2 169.4208 34.39831 1 2
4: 2 169.4208 34.39831 1 3
5: 2 169.4208 34.39831 1 4
6: 2 169.4208 34.39831 1 5
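The expansion step replicates each id's row `length` times, so an id with lifetime 0 contributes no rows and `time` runs from 1 to that id's lifetime. A base-R sketch on a hypothetical three-id frame (`toy` and `long` are illustrative names, not from the question) makes this visible without data.table:

```r
# Hypothetical toy data: three ids with lifetimes 0, 2 and 3
toy <- data.frame(id = 1:3, length = c(0, 2, 3), X1 = c(170, 180, 175))

# Repeat each row 'length' times; rows with length 0 disappear
long <- toy[rep(seq_len(nrow(toy)), toy$length), ]

# Number the periods within each id: 1, 2, ..., length
long$time <- ave(long$id, long$id, FUN = seq_along)
long$length <- NULL
rownames(long) <- NULL
print(long)
```

With `sample(0:5, ...)` the observed `time` values are therefore 1 to 5, which is consistent with the suffixes `_1` to `_5` in the `cv$resample` output of the answer.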
What I have done so far:
library(caret)
fitControl <- trainControl(method = 'cv',number=5)
cv=train(as.factor(Ycateg)~.,
data = df,
method = "multinom",
maxit=150,
trControl = fitControl)
cv
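Because each id contributes several rows, plain `method = 'cv'` can place rows of the same id both in the training part and in the held-out part of a fold. A base-R sketch of id-grouped folds follows; the helper `make_group_folds` is hypothetical (caret's `groupKFold` serves a similar purpose). It returns, for each fold, the row indices used for training, which is the list format that the `index` argument of `trainControl` expects:

```r
# Hypothetical helper: assign whole ids to k folds and return, for each fold,
# the row indices used for TRAINING (format of trainControl(index = ...)).
make_group_folds <- function(group, k = 5, seed = 1) {
  set.seed(seed)                                          # note: resets the global RNG
  ids <- unique(group)
  fold_of_id <- sample(rep_len(seq_len(k), length(ids)))  # balanced random fold labels
  names(fold_of_id) <- as.character(ids)
  row_fold <- unname(fold_of_id[as.character(group)])     # fold label of every row
  folds <- lapply(seq_len(k), function(f) which(row_fold != f))
  names(folds) <- paste0("Fold", seq_len(k))
  folds
}

# Toy check: 10 ids with 1 to 3 rows each
group <- rep(1:10, times = c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1))
folds <- make_group_folds(group, k = 3)
# Number of ids shared between training rows and held-out rows, per fold
leak <- sapply(folds, function(tr) length(intersect(group[tr], group[-tr])))
print(leak)
```

With the question's data this could be passed as `trainControl(method = 'cv', index = make_group_folds(df$id, k = 5), ...)`, and reusing the same list for the random forest would evaluate both models on identical folds.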
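Since the model predicts probabilities at each time point, I want to compute the following for each time point:
1. The Brier score for each category of the dependent variable: $BS_{it,k} = (Y_{it,k} - p_{it,k})^2$, where $i$ denotes observation $i$ of the test fold, $t$ the time, and $k$ category $k$ of the dependent variable.
2. Summarize this per fold as $\overline{BS}_{t,k} = \frac{1}{n_t} \sum_{i=1}^{n_t} BS_{it,k}$, where $n_t$ is the number of observations with observation time $t$; hence the grouped computation.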
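The two steps can be sketched in base R on made-up probabilities for three classes (all numbers below are illustrative, not results from the question): square the difference between the 0/1 indicator $Y_{it,k}$ and the predicted probability $p_{it,k}$, then average within each time point.

```r
classes <- c("current", "default", "prepaid")
obs  <- factor(c("current", "default", "prepaid", "current"), levels = classes)
time <- c(1, 1, 2, 2)   # observation time t of each row
p <- matrix(c(0.6, 0.3, 0.1,
              0.2, 0.5, 0.3,
              0.1, 0.1, 0.8,
              0.4, 0.4, 0.2),
            ncol = 3, byrow = TRUE, dimnames = list(NULL, classes))

# Step 1: Y_{it,k} as 0/1 indicators, squared error per observation and class
Y  <- 1 * outer(as.character(obs), classes, "==")
se <- (Y - p)^2

# Step 2: (1 / n_t) * sum over the n_t observations with time t, per class k
bs <- rowsum(se, time) / as.vector(table(time))
print(bs)
```

`rowsum()` sums the rows of each time group and `table(time)` supplies $n_t$, so each row of `bs` holds one time point and each column one class.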
So in the end, what I want to report (e.g. for 3-fold CV, with time ranging from 0 to 5) is output like this:
fold time Brier_0 Brier_1 Brier_2
1 1 0 0.39758714 0.11703814 0.8711775
2 1 1 0.99461281 0.95051037 0.1503217
3 1 2 0.01791559 0.83653814 0.1553521
4 1 3 0.92067849 0.55275340 0.6466206
5 1 4 0.73112563 0.07603891 0.5769286
6 1 5 0.29500600 0.66219814 0.7590742
7 2 0 0.24691469 0.06736522 0.8612998
8 2 1 0.13629191 0.55973431 0.5617303
9 2 2 0.48006915 0.01357407 0.4515544
10 2 3 0.01257112 0.40250469 0.1814620
. . . . . .
I know I have to set up a custom version of summaryFunction, but I am really lost on how to do that. My main goal is not to tune the model but to validate it.
Answer
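One thing to note: summaryFunction can only return a single numeric vector (correct me if I'm wrong). In addition, the data argument of summaryFunction contains a column rowIndex, which can be used to pull other variables from the original dataset.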
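A minimal sketch of that contract, assuming (as stated above) that `data` carries `obs`, `pred`, one probability column per outcome level and `rowIndex`: the hypothetical `toySummary` below returns one named numeric vector, and it is called on a hand-built data frame (`fake`, also hypothetical) rather than through `train()`.

```r
# Hypothetical summary function with caret's signature (data, lev, model);
# it returns ONE named numeric vector per resample.
toySummary <- function(data, lev = NULL, model = NULL) {
  Y  <- 1 * outer(as.character(data$obs), lev, "==")  # 0/1 indicators, one column per level
  P  <- as.matrix(data[, lev])                        # predicted probabilities
  bs <- colMeans((Y - P)^2)                           # Brier score per class
  setNames(bs, paste0("BS_", lev))
}

lev <- c("current", "default", "prepaid")
fake <- data.frame(
  obs      = factor(c("current", "prepaid"), levels = lev),
  pred     = factor(c("current", "default"), levels = lev),
  current  = c(0.7, 0.2),
  default  = c(0.2, 0.5),
  prepaid  = c(0.1, 0.3),
  rowIndex = c(3L, 8L)
)
res <- toySummary(fake, lev = lev)
print(res)
```

Calling the function by hand like this is a cheap way to check the names and length of its return value before handing it to `trainControl(summaryFunction = ...)`.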
library(dplyr)   # for %>%, group_by() and summarise()

customSummary <- function (data, lev = NULL, model = NULL) {
  #browser() # essential for debugging
  # observed outcome as dummies, one column per level of the outcome
  Y_obs = model.matrix( ~ data[, "obs"] - 1)
  # predicted probabilities
  P_hat = as.data.frame(data[ , c("current","default","prepaid")])
  # row numbers of the held-out observations in the original data
  rows = data[,"rowIndex"]
  # time of each observation (looked up in the global df)
  time = df$time[rows]
  # put it all together
  df_temp = data.frame(Y_obs, P_hat, time)
  names(df_temp) = c("Y_cur","Y_def","Y_pre","p_cur","p_def","p_pre","time")
  # group by time and compute the Brier score per class
  out = df_temp %>% group_by(time) %>%
    summarise(BS_cur = mean((Y_cur - p_cur)^2),
              BS_def = mean((Y_def - p_def)^2),
              BS_pre = mean((Y_pre - p_pre)^2))
  out = as.data.frame(out)
  # caret can handle only one line of return: flatten to a single named
  # numeric vector BS_cur_1, ..., BS_cur_5, BS_def_1, ..., BS_pre_5
  out_final = unlist(lapply(c("BS_cur","BS_def","BS_pre"), function(v)
    setNames(out[[v]], paste(v, out$time, sep = "_"))))
  return(out_final)
}
# which type of cross validation to do
fitControl <- trainControl(method = 'cv',number=5,classProbs=TRUE,summaryFunction=customSummary, selectionFunction = "best", savePredictions = TRUE)
grid <- expand.grid(decay = 0 )
cv=train(as.factor(Ycateg)~.,
data = df,
method = "multinom",
maxit=150,
trControl = fitControl,
tuneGrid = grid
)
cv$resample
BS_cur_1 BS_cur_2 BS_cur_3 BS_cur_4 BS_cur_5 BS_def_1 BS_def_2 BS_def_3 BS_def_4 BS_def_5 BS_pre_1 BS_pre_2 BS_pre_3 BS_pre_4 BS_pre_5
1 0.1657623 0.1542842 0.1366912 0.1398001 0.2056348 0.1915512 0.2256758 0.2291467 0.2448737 0.2698545 0.1586134 0.2101389 0.1432483 0.2076886 0.1663780
2 0.1776843 0.1919503 0.1615440 0.1654297 0.1200515 0.2108787 0.2185783 0.2209958 0.2467931 0.2199898 0.1580643 0.1595971 0.2015860 0.1826029 0.1947144
3 0.1675981 0.1818885 0.1893253 0.1402550 0.1400997 0.2358501 0.2342476 0.2079819 0.1870549 0.2065355 0.2055711 0.1586077 0.1453172 0.1638555 0.2106146
4 0.1796041 0.1573086 0.1500860 0.1738538 0.1171626 0.2247850 0.2168341 0.2031590 0.1807209 0.2616180 0.1677508 0.1965577 0.1873078 0.1859176 0.1344115
5 0.1909324 0.1640292 0.1556209 0.1371598 0.1566207 0.2314311 0.1991000 0.2255612 0.2195158 0.2071910 0.1976272 0.1777507 0.1843828 0.1453439 0.1736540
Resample
1 Fold1
2 Fold2
3 Fold3
4 Fold4
5 Fold5
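Since the answer already sets `savePredictions = TRUE`, the held-out predictions are also stored in `cv$pred` (columns `obs`, `pred`, the class probabilities, `rowIndex`, `Resample`, plus the tuning parameters). A base-R sketch that builds the fold × time layout requested in the question from such a data frame; `grouped_brier` is a hypothetical helper, and `pred_df` and `time_vec` are small made-up stand-ins for `cv$pred` and `df$time`, not real results.

```r
# Hypothetical helper (not part of caret): mean of (Y - p)^2 per fold, time and class
grouped_brier <- function(pred_df, time_vec, lev) {
  time <- time_vec[pred_df$rowIndex]                    # time of each held-out row
  Y    <- 1 * outer(as.character(pred_df$obs), lev, "==")  # Y_{it,k} as 0/1
  P    <- as.matrix(pred_df[, lev])                     # p_{it,k}
  se   <- (Y - P)^2
  colnames(se) <- paste0("Brier_", lev)
  se_df <- data.frame(fold = pred_df$Resample, time = time, se)
  # (1 / n_t) * sum of squared errors within each fold x time cell
  out <- aggregate(. ~ fold + time, data = se_df, FUN = mean)
  out <- out[order(out$fold, out$time), ]
  rownames(out) <- NULL
  out
}

lev <- c("current", "default", "prepaid")
time_vec <- c(1, 2, 1, 2, 1, 2)                         # stand-in for df$time
pred_df <- data.frame(                                  # stand-in for cv$pred
  obs      = factor(c("current", "default", "prepaid",
                      "current", "default", "prepaid"), levels = lev),
  current  = c(0.6, 0.2, 0.1, 0.5, 0.3, 0.2),
  default  = c(0.3, 0.6, 0.2, 0.3, 0.4, 0.2),
  prepaid  = c(0.1, 0.2, 0.7, 0.2, 0.3, 0.6),
  rowIndex = 1:6,
  Resample = rep(c("Fold1", "Fold2"), each = 3)
)
out <- grouped_brier(pred_df, time_vec, lev)
print(out)
```

With the real objects the call would be `grouped_brier(cv$pred, df$time, lev)`. This assumes a single-row `tuneGrid` as in the answer (otherwise filter `cv$pred` to the chosen parameters first) and that `rowIndex` refers to rows of `df`, which holds here because no rows are dropped for missing values.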