如何创建自定义模型(使用插入符号中的循环/子模型技巧)

时间:2018-08-26 21:48:43

标签: r machine-learning r-caret gbm

我为此困扰了很长时间。我感觉像是一个绝对的克里斯汀,因为答案可能很痛苦地显而易见,但是我找不到一个单一的线程来解释如何做到这一点。

关于自定义模型创建的文档部分对我来说就像this。我感觉就像在我的教育过程中某个地方错过了一个非常具体的课程,每个人现在都记得,但是我,因为我发现的是“是的,只需创建一个自定义模型,然后完成”。

此处的实际问题:

我想获得gbmcaret的每个迭代的预测。例如,在gbm中,我只可以在n.trees中使用predict(..., n.trees = 1:100)

显然,在caret中,我需要使用一种称为子模型的技巧,这意味着-如果我理解正确的话-我必须创建自己的自定义模型。

但是我可以在getModelInfo('gbm')中看到某种循环功能!

$gbm$loop
function (grid) 
{
    loop <- plyr::ddply(grid, c("shrinkage", "interaction.depth", 
        "n.minobsinnode"), function(x) c(n.trees = max(x$n.trees)))
    submodels <- vector(mode = "list", length = nrow(loop))
    for (i in seq(along = loop$n.trees)) {
        index <- which(grid$interaction.depth == loop$interaction.depth[i] & 
            grid$shrinkage == loop$shrinkage[i] & grid$n.minobsinnode == 
            loop$n.minobsinnode[i])
        trees <- grid[index, "n.trees"]
        submodels[[i]] <- data.frame(n.trees = trees[trees != 
            loop$n.trees[i]])
    }
    list(loop = loop, submodels = submodels)

该如何使用?为什么默认情况下不起作用?我是否真的需要创建自定义模型-也许不需要?

免责声明1:我不想使用任何交叉验证。我只想针对单个gbm运行的每次迭代得出预测。

免责声明2:我不想在predict.gbm()上使用$finalModel,因为我还想测试其他算法,这些算法也利用了该子模型的技巧。我不想使用所有不同的算法特定的predict()函数,因为那为什么我还要打扰呢?

作为一个可复制的示例,我什至不知道该写些什么。代码没有问题。我只是不知道这东西应该如何工作。

1 个答案:

答案 0 :(得分:0)

这是一个示例,说明如何为每棵树提取测试数据的期望预测:

library(caret)
library(mlbench) #for the data set
data(Sonar) #some data set I always use on stack overflow

res <- train(Class~.,
             data = Sonar,
             method = "gbm",
             trControl = trainControl(method = "cv", #some evaluations scheme
                                      number = 5,
                                      savePredictions = "all"), #tell caret you would like to save all,
             tuneGrid = expand.grid(shrinkage = 0.01,
                                    interaction.depth = 2, 
                                    n.minobsinnode = 10,
                                    n.trees = 1:100)) #some random values and all the trees

res$pred #results are stored in here

基本上,您在帖子中显示的代码告诉插入符号不要调整所有n.tree模型,而只是对每个超参数组合使用max(n.trees)进行调整,然后使用它来获得{{ 1}}

一些情节

n.trees < max(n.trees)

enter image description here

您还可以选择不选择library(ggplot2) ggplot(res$results)+ geom_line(aes(x = n.trees, y = Accuracy)) ,因为这会导致内存不足的火车对象。而是使用savePredictions = "all"来计算所有所需指标。