为Caret CV创建自定义折叠

时间:2016-11-08 10:53:03

标签: r r-caret cross-validation

我使用插入符号包进行建模和交叉验证

model <- caret::train(mpg  ~ wt
                           + drat
                           + disp
                           + qsec
                           + as.factor(am),
                  data = mtcars,
                  method = "lm",
                  trControl = caret::trainControl(method = "cv",
                                                repeats=5,
                                                returnData =FALSE))

但是,我想通过trainControl传递与我的折叠相关的一组自定义索引。这可以通过IndexOut完成。

model <- caret::train(wt ~  + disp + drat,
                  data = mtcars,     
                  method = "lm",
                   trControl = caret::trainControl(method = "cv",
                                      returnData =FALSE,
                                      index = indicies$train,
                                      indexOut = indicies$test))

我挣扎的是我只想测试mtcars.am==0的mtcars中的行。因此,createFolds的使用不会起作用,因为您无法添加标准。有没有人知道允许将行索引到K-folds的任何其他函数,其中mtcars.am==0可以在创建indicies$test时添加标准?

2 个答案:

答案 0 :(得分:1)

我认为这应该有效。只需使用所需的行索引提供索引。

index = list(which(mtcars$am == 0))

model <- caret::train(
    wt ~  +disp + drat,
    data = mtcars,
    method = "lm",
    trControl = caret::trainControl(
        method = "cv",
        returnData = FALSE,
        index = index
    )
)

index参数是一个列表,因此您可以通过在索引中创建多个嵌套列表来为该列表提供任意数量的迭代。

答案 1 :(得分:0)

谢谢你的帮助。我最后通过修改createFolds的输出而不是最好的示例mtcars来到达那里,因为它是如此小的数据集,但你明白了这一点:

folds<-caret::createFolds(mtcars,k=2)

indicies<-list()

#Create training folds
indicies$train<-lapply(folds,function(x) which(!1:nrow(mtcars) %in% x))

#Create test folds based output "folds" and with criterion added
indicies$test<-lapply(folds,function(x) which(1:nrow(mtcars) %in% x & mtcars[,"am"]==1))