如何使用r中的脱字符号包在最佳调整的超参数的10倍交叉验证中获得每个折的预测?

时间:2019-07-09 10:36:50

标签: r r-caret

我试图使用R中的插入符号包,通过10倍交叉验证和3次重复来运行SVM模型。我想使用最佳调整的超参数来获得每一倍的预测结果。我正在使用以下代码

# Load packages
library(mlbench)
library(caret)

# Load data
data(BostonHousing)

#Dividing the data into train and test set
set.seed(101)
sample <- createDataPartition(BostonHousing$medv, p=0.80, list = FALSE)
train <- BostonHousing[sample,]
test <- BostonHousing[-sample,]

control <- trainControl(method='repeatedcv', number=10, repeats=3, savePredictions=TRUE)
metric <- 'RMSE'

# Support Vector Machines (SVM) 
set.seed(101)
fit.svm <- train(medv~., data=train, method='svmRadial', metric=metric,
                 preProc=c('center', 'scale'), trControl=control)
fit.svm$bestTune
fit.svm$pred 

fit.svm$pred使用超参数的所有组合给我预测。但是我只希望对重复的每10倍平均值使用优化过的超参数进行预测。

1 个答案:

答案 0 :(得分:1)

一种实现目标的方法是使用r中的超级参数对fit.svm$pred进行子集设置,然后通过CV复制汇总所需的度量。我将使用fit.svm$bestTune执行此操作:

dplyr

输出:

library(tidyverse)
library(caret)
fit.svm$pred %>%
  filter(sigma == fit.svm$bestTune$sigma & C == fit.svm$bestTune$C) %>% #subset 
  mutate(fold = gsub("\\..*", "", Resample), #extract fold info from resample info
         rep = gsub(".*\\.(.*)", "\\1", Resample)) %>% #extract replicate info from resample info
  group_by(rep) %>% #group by replicate
  summarise(rmse = RMSE(pred, obs)) #aggregate the desired measure

编辑:如果您不喜欢使用正则表达式,或者只想节省一些键入内容,则可以使用# A tibble: 3 x 2 rep rmse <chr> <dbl> 1 Rep1 4.02 2 Rep2 3.96 3 Rep3 4.06

dplyr::separate

EDIT2:回应评论。将观测值和预测值写入CSV。文件:

fit.svm$pred %>%
  filter(sigma == fit.svm$bestTune$sigma & C == fit.svm$bestTune$C) %>%
  separate(Resample, c("fold", "rep"), "\\.") %>%
  group_by(rep) %>%
  summarise(rmse = RMSE(obs, pred))