Modifying summaryFunction in caret to compute a grouped Brier score

Time: 2017-08-23 21:39:18

Tags: r customization r-caret

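I want to compare a multinomial logit model and a random forest in cross-validation using a grouped Brier score. The theoretical basis for this approach is: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3702649/pdf/nihms461154.pdf

My dependent variable has three outcomes, and my dataset contains lifetime data with lifetimes between 0 and 5.

To make things reproducible, my dataset looks like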

library(data.table)
library(dplyr) # provides ntile() used below
N      <- 1000
X1     <- rnorm(N, 175, 7)
X2     <- rnorm(N,  30, 8)
length   <- sample(0:5,N,T)
Ycont  <- 0.5*X1 - 0.3*X2 + 10 + rnorm(N, 0, 6)
Ycateg <- ntile(Ycont,3)
df     <- data.frame(id=1:N,length,X1, X2, Ycateg)
df$Ycateg=ifelse(df$Ycateg==1,"current",ifelse(df$Ycateg==2,"default","prepaid"))

df=setDT(df)[,.SD[rep(1L,length)],by = id]
df=df[ , time := 1:.N , by=id]
df=df[,-c("length")]
head(df)
   id       X1       X2 Ycateg time
1:  1 178.0645 10.84313      1    1
2:  2 169.4208 34.39831      1    1
3:  2 169.4208 34.39831      1    2
4:  2 169.4208 34.39831      1    3
5:  2 169.4208 34.39831      1    4
6:  2 169.4208 34.39831      1    5

What I have done so far is

library(caret)
fitControl <- trainControl(method = 'cv',number=5)

cv=train(as.factor(Ycateg)~.,
         data = df,
         method = "multinom",
         maxit=150,
         trControl = fitControl)
cv

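Since the model is used to predict probabilities at each time point, I want to compute the following for each time point:

  1. The Brier score for each category of the dependent variable: BS_i = (Y_it,k - p_it,k)², where i denotes observation i of the test fold, t the time, and k category k of the dependent variable

  2. Summarize this for fold 1 by computing 1/n_t * sum(BS_i), where n_t is the number of observations with observed time t, hence the grouped computation

  3. So in the end, what I want to report, e.g. for a 3-fold CV and knowing that time ranges from 0 to 5, is output like this: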

       fold time    Brier_0    Brier_1   Brier_2
    1     1    0 0.39758714 0.11703814 0.8711775
    2     1    1 0.99461281 0.95051037 0.1503217
    3     1    2 0.01791559 0.83653814 0.1553521
    4     1    3 0.92067849 0.55275340 0.6466206
    5     1    4 0.73112563 0.07603891 0.5769286
    6     1    5 0.29500600 0.66219814 0.7590742
    7     2    0 0.24691469 0.06736522 0.8612998
    8     2    1 0.13629191 0.55973431 0.5617303
    9     2    2 0.48006915 0.01357407 0.4515544
    10    2    3 0.01257112 0.40250469 0.1814620
     .    .    .    .              .       .
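
Written out, steps 1 and 2 amount to the following per-class average within each time point. The index-set notation is my own shorthand, with t_i denoting the observed time of observation i in the test fold:

```latex
\mathrm{BS}_{t,k} \;=\; \frac{1}{n_t} \sum_{i \,:\, t_i = t} \left( Y_{it,k} - p_{it,k} \right)^2,
\qquad
n_t \;=\; \#\{\, i : t_i = t \,\}
```

Here Y_it,k is 1 if observation i falls in category k at time t and 0 otherwise, and p_it,k is the predicted probability of category k.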
    

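I know I have to set up a custom version of summaryFunction, but I am really lost on how to do that. My main goal is not to tune the model but to validate it.

1 Answer:

Answer 0: (score: 0)

One thing to note: summaryFunction can only return a numeric vector (correct me if I am wrong). Also, the data argument of summaryFunction contains a column rowIndex, which can be used to pull additional variables from the original dataset.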
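
Before the full function, the core computation can be sketched on a toy data frame shaped like the `data` argument. This is only an illustration under assumptions: the probabilities, the `time_by_row` lookup and the `BS_<class>_<time>` naming are made up for the example and are not part of caret.

```r
# Toy stand-in for the `data` argument passed to summaryFunction.
# All values are invented for illustration.
lev <- c("current", "default", "prepaid")
toy <- data.frame(
  obs      = factor(c("current", "default", "prepaid", "current"), levels = lev),
  current  = c(0.7, 0.2, 0.1, 0.50),
  default  = c(0.2, 0.5, 0.3, 0.25),
  prepaid  = c(0.1, 0.3, 0.6, 0.25),
  rowIndex = c(1, 2, 3, 4)
)
# Hypothetical lookup: time point of each row of the original data set.
time_by_row <- c(1, 1, 2, 2)

# One-hot encode the observed class (rows = observations, columns = classes).
Y <- sapply(lev, function(k) as.numeric(toy$obs == k))
P <- as.matrix(toy[, lev])
sq_err <- (Y - P)^2

# Average the squared errors within each time point, separately per class.
grp <- time_by_row[toy$rowIndex]
bs  <- aggregate(as.data.frame(sq_err), by = list(time = grp), FUN = mean)

# Flatten to a named numeric vector, the shape summaryFunction has to return.
res <- unlist(bs[, lev])
names(res) <- paste("BS", rep(lev, each = nrow(bs)), bs$time, sep = "_")
res
```

The grouping is done with base R's `aggregate()` here purely to keep the sketch dependency-free; a dplyr `group_by()` plus `summarise()` does the same job.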

library(dplyr) # needed for %>%, group_by() and summarise() used below

customSummary <- function (data, lev = NULL, model = NULL) { # grouped Brier score per class and time point
  #browser() #essential for debugging
  dat=dim(data)

  # get observed dummy
  Y_obs = model.matrix( ~ data[, "obs"] - 1) # create dummy - for each level of the outcome
  # get predicted probabilities
  Y_pre=as.data.frame(data[ , c("current","default","prepaid")])
  # get rownumbers
  rows=data[,"rowIndex"]
  # get time of each obs
  time=df[rows,]$time
  # put it all together
  df_temp=data.frame(Y_obs,Y_pre,time)
  names(df_temp)=c("Y_cur","Y_def","Y_pre","p_cur","p_def","p_pre","time")
  # group by time and compute Brier score
  out=df_temp %>% group_by(time) %>% summarise(BS_cur=1/n()*sum((Y_cur-p_cur)^2),BS_def=1/n()*sum((Y_def-p_def)^2),BS_pre=1/n()*sum((Y_pre-p_pre)^2))
  # name 
  names(out)=c("time","BS_cur","BS_def","BS_pre")
  # now flatten to one named numeric vector - caret seems to be able to handle only that
  out=as.data.frame(out)
  bs_cols=c("BS_cur","BS_def","BS_pre")
  # unlist() stacks the score columns one after another (all times for BS_cur, then BS_def, ...)
  out_final=unlist(out[,bs_cols])
  # names follow the pattern <score>_<time>, e.g. BS_cur_1
  names(out_final)=paste(rep(bs_cols,each=nrow(out)),out$time,sep="_")


  return(out_final)
}


# which type of cross validation to do
fitControl <- trainControl(method = 'cv',number=5,classProbs=TRUE,summaryFunction=customSummary, selectionFunction = "best", savePredictions = TRUE)


grid <- expand.grid(decay = 0 )

cv=train(as.factor(Ycateg)~.,
         data = df,
         method = "multinom",
         maxit=150,
         trControl = fitControl,
         tuneGrid = grid
)

cv$resample
  BS_cur_1  BS_cur_2  BS_cur_3  BS_cur_4  BS_cur_5  BS_def_1  BS_def_2  BS_def_3  BS_def_4  BS_def_5  BS_pre_1  BS_pre_2  BS_pre_3  BS_pre_4  BS_pre_5
1 0.1657623 0.1542842 0.1366912 0.1398001 0.2056348 0.1915512 0.2256758 0.2291467 0.2448737 0.2698545 0.1586134 0.2101389 0.1432483 0.2076886 0.1663780
2 0.1776843 0.1919503 0.1615440 0.1654297 0.1200515 0.2108787 0.2185783 0.2209958 0.2467931 0.2199898 0.1580643 0.1595971 0.2015860 0.1826029 0.1947144
3 0.1675981 0.1818885 0.1893253 0.1402550 0.1400997 0.2358501 0.2342476 0.2079819 0.1870549 0.2065355 0.2055711 0.1586077 0.1453172 0.1638555 0.2106146
4 0.1796041 0.1573086 0.1500860 0.1738538 0.1171626 0.2247850 0.2168341 0.2031590 0.1807209 0.2616180 0.1677508 0.1965577 0.1873078 0.1859176 0.1344115
5 0.1909324 0.1640292 0.1556209 0.1371598 0.1566207 0.2314311 0.1991000 0.2255612 0.2195158 0.2071910 0.1976272 0.1777507 0.1843828 0.1453439 0.1736540
  Resample
1    Fold1
2    Fold2
3    Fold3
4    Fold4
5    Fold5
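
To get from this wide layout to the fold/time table asked for in the question, the columns can be reshaped afterwards. The following is a minimal sketch on a made-up stand-in for `cv$resample` (two folds, two time points, invented values); it assumes the column names follow the `BS_<class>_<time>` pattern shown above.

```r
# Toy stand-in for cv$resample; the numbers are invented for illustration.
resample <- data.frame(
  BS_cur_1 = c(0.16, 0.17), BS_cur_2 = c(0.15, 0.19),
  BS_def_1 = c(0.19, 0.21), BS_def_2 = c(0.22, 0.21),
  BS_pre_1 = c(0.15, 0.15), BS_pre_2 = c(0.21, 0.16),
  Resample = c("Fold1", "Fold2")
)

metric_cols <- setdiff(names(resample), "Resample")
# Split e.g. "BS_cur_1" into the score part ("BS_cur") and the time part (1).
score_id <- sub("_[0-9]+$", "", metric_cols)
time_id  <- as.integer(sub("^.*_", "", metric_cols))

# Build one block of rows per time point, with one column per score.
long <- do.call(rbind, lapply(sort(unique(time_id)), function(tt) {
  block <- resample[, metric_cols[time_id == tt], drop = FALSE]
  names(block) <- score_id[time_id == tt]
  cbind(fold = resample$Resample, time = tt, block)
}))

# Order by fold, then time, to match the layout requested in the question.
long <- long[order(long$fold, long$time), ]
rownames(long) <- NULL
long
```

With the real object, replacing the toy `resample` by `cv$resample` should give one row per fold and time point, provided every fold contains every time point.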