指定仅带有尖号的最终模型的选项

时间:2018-11-20 12:31:16

标签: r r-caret bartmachine

上下文

我正在使用caret来拟合和调整模型。通常,最好使用重采样方法(例如交叉验证)找到最佳参数。选择最佳参数后,将使用最佳参数集将最终模型拟合到整个训练数据中。

除了要调整的参数(通过tuneGrid传递)外,还可以将参数传递给train来传递参数给底层算法。

我的问题

是否有任何方法可以指定仅用于最终模型的特定于模型的选项?

为了更加清晰:我确实想拟合所有中间模型(以获得可靠的性能估计),但我想为最终模型拟合不同的参数(除了最佳参数)。

特定用例

假设我想将bartMachine应用于某些数据,然后在生产中使用最终模型。我通常会将调整后的模型保存到磁盘并根据需要加载。但是我只能保存/加载已序列化的bartMachine模型,即我需要通过serialize=TbartMachine传递到caret::train

但这会序列化所有模型,这是非常不切实际的。我真的只需要序列化最终模型。有什么办法吗?

library("caret")
library("bartMachine")
tgrid <- expand.grid(num_trees = 100,
                       k = c(2, 3),
                       alpha = 0.95, 
                       beta = 2,
                       nu =  3)
# The printed log shows that all intermediate models are being serialized
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=T,
             tuneGrid=tgrid,
             trControl = trainControl(method="cv", 5, verboseIter=T))

1 个答案:

答案 0 :(得分:0)

要使模型适合整个数据集而不进行参数调整或重新采样,请将火车控制方法修改为无:

tgrid <- expand.grid(num_trees = 100,
                     k = 2,
                     alpha = 0.95, 
                     beta = 2,
                     nu =  3)
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=TRUE,
             tuneGrid=tgrid,
             trControl = trainControl(method="none"))

请注意,我已经删除了问题代码中两个k值之一。 否则会出现错误:Only one model should be specified in tuneGrid with no resampling。我建议使用其他k值构建一个单独的模型。

上面的代码提供以下输出:

bartMachine initializing with 100 trees...
bartMachine vars checked...
bartMachine java init...
bartMachine factors created...
bartMachine before preprocess...
bartMachine after preprocess... 11 total features...
bartMachine sigsq estimated...
bartMachine training data finalized...
Now building bartMachine for regression ...
building BART with mem-cache speedup...
Iteration 100/1250  mem: 17.6/477.1MB
Iteration 200/1250  mem: 25.1/477.1MB
Iteration 300/1250  mem: 30.8/477.1MB
Iteration 400/1250  mem: 39.9/477.1MB
Iteration 500/1250  mem: 19/477.1MB
Iteration 600/1250  mem: 59.6/477.1MB
Iteration 700/1250  mem: 39.6/477.1MB
Iteration 800/1250  mem: 79.8/477.1MB
Iteration 900/1250  mem: 119.9/477.1MB
Iteration 1000/1250  mem: 40.7/477.1MB
Iteration 1100/1250  mem: 80.8/477.1MB
Iteration 1200/1250  mem: 121/477.1MB
done building BART in 1.289 sec 

burning and aggregating chains from all threads... done
evaluating in sample data...done
serializing in order to be saved for future R sessions...done

fit$finalModel中将serialize参数设置为TRUE:

fit$finalModel$serialize
[1] TRUE

对于其价值,bartMachine内部的check_serialization函数不会给出任何警告或错误(或任何其他输出):

bartMachine:::check_serialization(fit$finalModel)

我不清楚如何从fit$finalModel中提取序列化的对象。 我假设它存储在fit$finalModel$java_bart_machine中,其中包含一个rJava指针。使用bartMachine所依赖的rJava包可能会获得进一步的了解。

更新: @ antoine-sac在下面的注释中指出“ serialize = T不会导致保存模型,而是将样本序列化到模型中,这意味着在将模型写入磁盘时会保存它们。”