使用插入符号进行机器学习:如何指定超时?

时间:2018-05-19 11:04:31

标签: r timeout r-caret

使用train库中的caret在R中训练模型时是否可以指定超时? 如果没有,是否存在包含代码的R构造,并且可以在一定时间后终止?

1 个答案:

答案 0 :(得分:3)

使用trainControl()对象配置插入符号选项。它没有指定超时时间的参数。

trainControl()中对运行时性能影响最大的两个设置是method=number=。插入符号中的默认方法是boot或引导。除非number,否则引导方法的默认method="cv"为25。

因此,带有插入符号的randomForest运行将执行25次迭代引导样本,这是一个非常慢的过程,尤其是在单处理器线程上运行时。

强制超时

可以通过R.utils包中的withTimeout()函数为R函数指定超时时间。

例如,我们将通过插入符号使用mtcars数据集运行随机林,并执行500次迭代的bootstrap采样以使train()运行超过15秒。我们将使用withTimeout()在15秒的CPU时间后停止处理。

data(mtcars)
library(randomForest)
library(R.utils)
library(caret)
fitControl <- trainControl(method = "boot",
                           number = 500,
                           allowParallel = FALSE)

withTimeout(
     theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
     ,timeout=15)

...和输出的第一部分:

> withTimeout(
+      theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
+      ,timeout=15)
[2018-05-19 07:32:37] TimeoutException: task 2 failed - "reached elapsed time limit" [cpu=15s, elapsed=15s]

提高caret表现

除了简单地超时caret::train()功能外,我们还可以使用两种技术来提高caret::train()的效果,并行处理和调整trainControl()参数。

  1. 编写R脚本以使用并行处理需要paralleldoParallel()包,这是一个多步骤的过程。
  2. method="boot"更改为method="cv"(k倍交叉验证)并将number=缩减为35将显着提高caret::train()的运行时性能1}}。
  3. 总结我之前在Improving Performance of Random Forest with caret::train()中描述的技术,以下代码使用Sonar数据集来实现与caretrandomForest的并行处理。

    #
    # Sonar example from caret documentation
    #
    
    library(mlbench)
    library(randomForest) # needed for varImpPlot
    data(Sonar)
    #
    # review distribution of Class column
    # 
    table(Sonar$Class)
    library(caret)
    set.seed(95014)
    
    # create training & testing data sets
    
    inTraining <- createDataPartition(Sonar$Class, p = .75, list=FALSE)
    training <- Sonar[inTraining,]
    testing <- Sonar[-inTraining,]
    
    #
    # Step 1: configure parallel processing
    # 
    
    library(parallel)
    library(doParallel)
    cluster <- makeCluster(detectCores() - 1) # convention to leave 1 core for OS 
    registerDoParallel(cluster)
    
    #
    # Step 2: configure trainControl() object for k-fold cross validation with
    #         5 folds
    #
    
    fitControl <- trainControl(method = "cv",
                               number = 5,
                               allowParallel = TRUE)
    
    #
    # Step 3: develop training model
    #
    
    system.time(fit <- train(Class ~ ., method="rf",data=Sonar,trControl = fitControl))
    
    #
    # Step 4: de-register cluster
    #
    stopCluster(cluster)
    registerDoSEQ()
    #
    # Step 5: evaluate model fit 
    #
    fit
    fit$resample
    confusionMatrix.train(fit)
    #average OOB error from final model
    mean(fit$finalModel$err.rate[,"OOB"])
    
    plot(fit,main="Accuracy by Predictor Count")
    varImpPlot(fit$finalModel,
               main="Variable Importance Plot: Random Forest")
    sessionInfo()