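Creating a function that builds groups in a data.table from predicted values and computes a confidence interval for the difference in means

Time: 2018-03-12 17:52:02

Tags: r function statistics data.table confidence-interval

Hello, I have the following sample data.table: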

library(data.table)
library(ggplot2)  # provides cut_number()

n = 100000

DT = data.table(dummy = rbinom(n, 1, 0.4),
                observed = 50 + sample.int(52, size = n, replace = TRUE),
                predicted = sample.int(102, size = n, replace = TRUE))

head(DT)

I need to write a function that assigns each row to one of 20 groups, numbered 1 to 20 in ascending order of predicted (using cut_number from ggplot2).
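To make the grouping step concrete, here is a minimal sketch on a toy table. The table `toy` and the summary column names `min_pred` and `max_pred` are made up for illustration and are not part of my real data.

```r
library(data.table)
library(ggplot2)  # provides cut_number()

# Toy data: 100 distinct predicted values (hypothetical, for illustration only)
toy <- data.table(predicted = 1:100)

# cut_number() splits the values into n groups with roughly equal counts;
# as.integer() turns the resulting factor into group numbers 1..20
toy[, score_group := as.integer(cut_number(predicted, n = 20))]

# One row per group, ordered by group number
group_summary <- toy[, .(N = .N,
                         min_pred = min(predicted),
                         max_pred = max(predicted)),
                     keyby = score_group]
print(group_summary)
```

With 100 evenly spread values, each of the 20 groups should end up holding 5 rows, and group 1 covers the lowest 5% of predicted.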

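For each group, I need to compute the difference between the mean of observed where dummy == 0 and the mean of observed where dummy == 1.

For example, group 1 would be the lowest 5% of the data.table based on predicted expenditure. For this group, I need a column holding the difference between the mean of observed for dummy == 0 and for dummy == 1.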

Then I need to create two columns with the lower and upper confidence limits (using the standard error of the difference in means).
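Written out, the interval I have in mind is presumably the usual unpooled one. The symbols here are notation introduced only for this note: \bar{x}_1, s_1^2, n_1 are the mean, variance, and row count of observed where dummy == 1 within a group, and \bar{x}_0, s_0^2, n_0 are the same for dummy == 0. The degrees of freedom follow what my code attempts, (n_1 - 1) + (n_0 - 1); a Welch-Satterthwaite correction would be an alternative.

```latex
\[
\mathrm{SE} = \sqrt{\frac{s_1^2}{n_1} + \frac{s_0^2}{n_0}},
\qquad
\mathrm{CI}_{95\%} = (\bar{x}_1 - \bar{x}_0) \pm t_{0.975,\; n_1 + n_0 - 2} \cdot \mathrm{SE}
\]
```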

Here is the current code for the function:

# Create function
Table <- function(observed1, observed0, predicted, N_groups = 20) {
  DT1 = data.table(observed0 = observed0, observed1 = observed1, predicted = predicted)
  DT1[, score_group := as.integer(cut_number(predicted, n = N_groups))]
  DT2 = DT1[, .(predicted = mean(predicted),
                X = mean(observed1),
                Y = mean(observed0),
                N = .N,
                std_error = sqrt((var(observed1)/nrow(observed1)) + (var(observed0)/nrow(observed0)))),
            keyby = score_group]

  # Standard error of difference in means
  DT2[, `:=`(lower_CI = X - Y + qt(0.025, df = nrow(observed1)-1 + nrow(observed0)-1)*std_error,
             upper_CI = X - Y + qt(0.975, df = nrow(observed1)-1 + nrow(observed0)-1)*std_error),
      keyby = score_group]
  return(DT2)
}

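This is clearly wrong: I get NA for the standard error and for both confidence limits, and the function does not pick up the correct number of rows for dummy == 1 and dummy == 0. It uses the total number of rows in each group (about n / 20) for both, which is not what it should do.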
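One likely source of the NA values, as far as I can tell, is that nrow() is being called on plain vectors. Inside a grouped data.table expression each column is an ordinary vector, and nrow() of a vector is NULL, so the division yields a zero-length result. The vector `x` below is made up purely to show this behaviour.

```r
# A plain numeric vector, standing in for a column inside a grouped j expression
x <- c(10, 12, 15, 11)

n_rows <- nrow(x)           # NULL: a vector has no dim attribute
n_len  <- length(x)         # 4: the actual number of elements

bad_se  <- sqrt(var(x) / n_rows)  # zero-length numeric, not a usable number
good_se <- sqrt(var(x) / n_len)   # standard error of the mean of x

print(n_rows)
print(length(bad_se))
print(good_se)
```

So length() (or .N where the whole group is meant) seems to be the right tool here, not nrow().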

This is how I currently call the function:

Table(DT[dummy == 1, observed],
      DT[dummy == 0, observed],
      DT$predicted,
      N_groups = 20)

I really do need a function and a data.table solution. Please help!

    score_group  predicted        X        Y    N std_error lower_CI upper_CI
 1:           1   3.552568 76.40180 76.71253 5764        NA       NA       NA
 2:           2   9.020897 76.59119 76.44898 4929        NA       NA       NA
 3:           3  13.990248 76.50569 76.72836 4922        NA       NA       NA
 4:           4  19.012041 76.71393 76.13349 4817        NA       NA       NA
 5:           5  23.992017 76.63382 76.98256 4760        NA       NA       NA
 6:           6  29.017115 76.27058 76.51243 4908        NA       NA       NA
 7:           7  34.005055 76.45592 76.46684 4946        NA       NA       NA
 8:           8  38.984464 76.25172 76.58368 5085        NA       NA       NA
 9:           9  44.000000 76.13215 76.49909 4949        NA       NA       NA
10:          10  49.510715 76.65609 76.12189 5833        NA       NA       NA
11:          11  55.008938 76.47674 76.41519 4923        NA       NA       NA
12:          12  59.986612 76.32976 76.14253 4855        NA       NA       NA
13:          13  64.958274 76.35325 76.84586 4937        NA       NA       NA
14:          14  70.013035 76.75316 76.79226 4833        NA       NA       NA
15:          15  75.004029 76.62067 76.83602 4964        NA       NA       NA
16:          16  79.960352 76.62068 76.33824 4893        NA       NA       NA
17:          17  85.026246 76.04252 76.35788 4915        NA       NA       NA
18:          18  90.021586 76.78144 76.61000 4818        NA       NA       NA
19:          19  94.974199 76.44890 76.46483 4961        NA       NA       NA
20:          20 100.023657 76.42763 76.89435 4988        NA       NA       NA
Warning messages:
1: In data.table(observed0 = observed0, observed1 = observed1, predicted = predicted) :
  Item 1 is of size 59991 but maximum size is 100000 (recycled leaving remainder of 40009 items)
2: In data.table(observed0 = observed0, observed1 = observed1, predicted = predicted) :
  Item 2 is of size 40009 but maximum size is 100000 (recycled leaving remainder of 19982 items)
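
The recycling warnings suggest a second problem: splitting observed into two vectors of different lengths breaks the row-by-row link with predicted. One possible direction, offered as a sketch under assumptions and not as a verified answer, is to pass the whole table and split by dummy inside each group. The function name `diff_table`, the `level` argument, and the output columns `N1`, `N0`, `diff`, and `df` are hypothetical names introduced here; the degrees of freedom keep the (n1 - 1) + (n0 - 1) choice from my original code.

```r
library(data.table)
library(ggplot2)  # provides cut_number()

# Hypothetical replacement for Table(): expects columns dummy, observed, predicted
diff_table <- function(DT, N_groups = 20, level = 0.95) {
  DT1 <- copy(DT)  # work on a copy so the caller's table is not modified by reference
  DT1[, score_group := as.integer(cut_number(predicted, n = N_groups))]

  DT2 <- DT1[, {
    x1 <- observed[dummy == 1]  # observed values for dummy == 1 in this group
    x0 <- observed[dummy == 0]  # observed values for dummy == 0 in this group
    n1 <- length(x1)
    n0 <- length(x0)
    .(predicted = mean(predicted),
      X = mean(x1),
      Y = mean(x0),
      N1 = n1,
      N0 = n0,
      diff = mean(x1) - mean(x0),
      # unpooled standard error of the difference in means
      std_error = sqrt(var(x1) / n1 + var(x0) / n0),
      df = n1 + n0 - 2)
  }, keyby = score_group]

  alpha <- 1 - level
  DT2[, `:=`(lower_CI = diff + qt(alpha / 2, df) * std_error,
             upper_CI = diff + qt(1 - alpha / 2, df) * std_error)]
  DT2[]
}
```

It would be called as diff_table(DT, N_groups = 20). Keeping dummy as a column means the per-group counts come from length() of each subset, so they no longer default to the full group size.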

0 answers:

No answers