How to apply cross-validation to a random forest

Time: 2019-05-06 14:37:48

Tags: r random-forest cross-validation

I want to run cross-validation on a random forest regression, but I am not sure how to do it. This is my code so far:

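library(rfUtilities)
library(randomForest)
library(dplyr)   # select(), filter(), %>%
library(tidyr)   # drop_na()

# Read data (the path is a placeholder, as in the original post)
base <- readxl::read_xlsx("c:/ File")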

# Pull columns to use in the model
base_cl <- select(base, 
                  Id = PLA_WTWPartyID, 
                  Ind =Global_reference_Industry, 
                  Num__Ind =NumInd,
                  Retention = Retention_AL,
                  Limit = Limit_AL,
                  Exposure = Exposure_AL,
                  #RL_Exposure = Risk_level_Exposure,
                  LPremium = Liab_Premuim_AL,
                  Haz_Gp = HazardGp_AL,
                  LPick =Loss_Pick_AL,
                  #RL_LPick = Level_Loss_Pick,
                  Rate = Rate_AL,
                  lob = AL_R,
                  Date = AL_R_Date) 

#Clean Data
base_cl$Ind[is.na(base_cl$Ind)] <- "Other"
base_cl$Limit[base_cl$Limit == "0"] <- NA
base_cl$Exposure[base_cl$Exposure == "0"] <- NA

#Remove Rate outliers (remove_outliers() is a user-defined helper not shown in the post)
base_cl$Rate <- remove_outliers(base_cl$Rate)

base_cl <- base_cl %>%
  filter(lob == "1") %>%
  filter(Date == "1") %>%
  drop_na(Limit)%>%
  drop_na(Exposure) %>%
  drop_na(LPremium) %>%
  drop_na(Retention) %>%
  drop_na(Rate)

# Fit the forest (Formula_3 is a model formula defined elsewhere, not shown in the post)
output.forest <- randomForest(Formula_3, base_cl, ntree = 400, keep.forest = TRUE,
                              importance = TRUE, localImp = TRUE, mtry = 6)

print(output.forest)
rf.regression.fit(output.forest)
varImpPlot(output.forest, sort = TRUE)

RF_CV_2 <- rfcv(trainx = base_cl[, 4:9], trainy = base_cl[[10]], p = .2,
                normalize = TRUE, bootstrap = TRUE, trace = TRUE, step = 3, method = "cv")

This last call produces an error:

RF <- rf.crossValidation(output.forest, base_cl, p = 0.1, n = 99, seed = NULL,
                         normalize = FALSE, bootstrap = FALSE, trace = FALSE, ntree = 400)
  

Error in sample.int(length(x), size, replace, prob) : object 'sample.sizes' not found

...and I don't know how to fix it. Could you help me build a function or fix the code so that cross-validation runs, perhaps with k = 5 or 10?

1 Answer:

Answer 0: (score: 0)

From a Google search:

library(tidyverse)

# Build Poisson distributions

p_dat <- map_df(1:10, ~ tibble(
  l = paste(.),
  x = 0:20,
  y = dpois(0:20, .)
))

# Build Normal distributions

n_dat <- map_df(1:10, ~ tibble(
  l = paste(.),
  x = seq(0, 20, by = 0.001),
  y = dnorm(seq(0, 20, by = 0.001), ., sqrt(.))
))

# Use ggplot2 to plot

ggplot(n_dat, aes(x, y, color = factor(l, levels = 1:10))) +
  geom_line() +
  geom_point(data = p_dat, aes(x, y, color = factor(l, levels = 1:10))) +
  labs(color = "Lambda:") +
  theme_minimal()

...we find that the bug was fixed in February, but you need to install the development version from GitHub. See the bug report and the response here: https://github.com/jeffreyevans/rfUtilities/issues/4
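
Since the question also asks for a k = 5 or 10 cross-validation, here is a minimal sketch of one way to do it by hand, independent of rfUtilities. It assumes the randomForest package is installed. The helper kfold_rf_cv() is a hypothetical name written for this illustration and is not part of either package, and mtcars only stands in for the asker's data. The commented-out install line assumes the remotes package and takes the repository name from the issue URL above.

```r
library(randomForest)

# Possible way to get the development version mentioned above
# (assumes the remotes package; repository name taken from the issue URL):
# remotes::install_github("jeffreyevans/rfUtilities")

# Hypothetical helper, not part of randomForest or rfUtilities:
# k-fold cross-validation for a random forest regression.
# Assumes the left-hand side of the formula is a plain column name.
kfold_rf_cv <- function(formula, data, k = 5, seed = 42, ...) {
  set.seed(seed)
  # Randomly assign every row to one of k roughly equal folds
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))
  response <- all.vars(formula)[1]

  rmse <- vapply(seq_len(k), function(i) {
    train <- data[folds != i, , drop = FALSE]
    test  <- data[folds == i, , drop = FALSE]
    # Fit on k - 1 folds, then score on the held-out fold
    fit  <- randomForest(formula, data = train, ...)
    pred <- predict(fit, newdata = test)
    sqrt(mean((test[[response]] - pred)^2))
  }, numeric(1))

  list(fold_rmse = rmse, mean_rmse = mean(rmse))
}

# Demo on a built-in dataset
cv5 <- kfold_rf_cv(mpg ~ ., data = mtcars, k = 5, ntree = 400)
print(cv5$fold_rmse)  # one RMSE per held-out fold
print(cv5$mean_rmse)  # cross-validated RMSE
```

With the data from the question, the call would look something like kfold_rf_cv(Formula_3, base_cl, k = 10, ntree = 400, mtry = 6), provided base_cl has no missing values in the columns the formula uses. RMSE is just one choice of error metric here; any other regression metric could be computed on the held-out fold instead.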