使用train
库中的caret
在R中训练模型时是否可以指定超时?
如果没有,是否存在包含代码的R构造,并且可以在一定时间后终止?
答案 0 :(得分:3)
使用trainControl()
对象配置插入符号选项。它没有指定超时时间的参数。
trainControl()
中对运行时性能影响最大的两个设置是method=
和number=
。插入符号中的默认方法是boot
或引导。除非number
,否则引导方法的默认method="cv"
为25。
因此,带有插入符号的randomForest
运行将执行25次迭代引导样本,这是一个非常慢的过程,尤其是在单处理器线程上运行时。
可以通过R.utils
包中的withTimeout()
函数为R函数指定超时时间。
例如,我们将通过插入符号使用mtcars数据集运行随机林,并执行500次迭代的bootstrap采样以使train()
运行超过15秒。我们将使用withTimeout()
在15秒的CPU时间后停止处理。
data(mtcars)
library(randomForest)
library(R.utils)
library(caret)
fitControl <- trainControl(method = "boot",
number = 500,
allowParallel = FALSE)
withTimeout(
theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
,timeout=15)
...和输出的第一部分:
> withTimeout(
+ theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
+ ,timeout=15)
[2018-05-19 07:32:37] TimeoutException: task 2 failed - "reached elapsed time limit" [cpu=15s, elapsed=15s]
caret
表现除了简单地超时caret::train()
功能外,我们还可以使用两种技术来提高caret::train()
的效果,并行处理和调整trainControl()
参数。
parallel
和doParallel()
包,这是一个多步骤的过程。 method="boot"
更改为method="cv"
(k倍交叉验证)并将number=
缩减为3
或5
将显着提高caret::train()
的运行时性能1}}。 总结我之前在Improving Performance of Random Forest with caret::train()中描述的技术,以下代码使用Sonar
数据集来实现与caret
和randomForest
的并行处理。
#
# Sonar example from caret documentation
#
library(mlbench)
library(randomForest) # needed for varImpPlot
data(Sonar)
#
# review distribution of Class column
#
table(Sonar$Class)
library(caret)
set.seed(95014)
# create training & testing data sets
inTraining <- createDataPartition(Sonar$Class, p = .75, list=FALSE)
training <- Sonar[inTraining,]
testing <- Sonar[-inTraining,]
#
# Step 1: configure parallel processing
#
library(parallel)
library(doParallel)
cluster <- makeCluster(detectCores() - 1) # convention to leave 1 core for OS
registerDoParallel(cluster)
#
# Step 2: configure trainControl() object for k-fold cross validation with
# 5 folds
#
fitControl <- trainControl(method = "cv",
number = 5,
allowParallel = TRUE)
#
# Step 3: develop training model
#
system.time(fit <- train(Class ~ ., method="rf",data=Sonar,trControl = fitControl))
#
# Step 4: de-register cluster
#
stopCluster(cluster)
registerDoSEQ()
#
# Step 5: evaluate model fit
#
fit
fit$resample
confusionMatrix.train(fit)
#average OOB error from final model
mean(fit$finalModel$err.rate[,"OOB"])
plot(fit,main="Accuracy by Predictor Count")
varImpPlot(fit$finalModel,
main="Variable Importance Plot: Random Forest")
sessionInfo()