使用自定义采样器

时间:2017-09-18 23:45:57

标签: r parallel-processing r-caret

我试图将自己的采样器提供给插入符号包的训练功能(因为数据不平衡),然后在并行环境中训练模型。 如果我不把采样器送到火车上,那就可以了。 如果我将采样器送到列车但不使用并联功能,那么它再次正常工作。 但是,如果我要求它与采样器并行运行,那么它会给我一个错误。我试过在两个不同的系统上运行,结果是一样的但是我在两种情况下得到的错误是不同的。这是一个例子:

library(caret)
set.seed(1)
data(iris)

library(DMwR)
library(doParallel)
cl <- makeCluster(3)
cl <- makeCluster(1) #uncommenting this will make the code work 
print(cl)
registerDoParallel(cl)

smote_wrapper <- list(
        name = "custom_smoting",
        func = function(x, y) {
                #print(dim(x))
                print(length(y))
                data <- cbind(x, data.frame(Class = y))
                #print(table(data$Class))
                print("calling smote")
                final <- SMOTE(Class~., data, perc.over = 50, perc.under = 50)
                print("smote over")
                #print(dim(final))
                final$Class <- as.factor(final$Class)
                print(table(final$Class))
                class_index <- which(colnames(final) == "Class")
                print(paste("dim:", dim(final)))
                result <- list(x = final[,-class_index], y = final$Class)
                result
        },
        first = FALSE
)
data(iris)
control <- trainControl(sampling = smote_wrapper)
model <- train(Species~., iris, method = "svmLinear2", trControl = control)
stopCluster(cl)

在一个系统上,它停止训练模式并给出错误:

Error in { : task 1 failed - "object 'out2' not found

在另一个系统上,它给出了:

Something is wrong; all the Accuracy metric values are missing:
Accuracy       Kappa    
 Min.   : NA   Min.   : NA  
 1st Qu.: NA   1st Qu.: NA  
 Median : NA   Median : NA  
 Mean   :NaN   Mean   :NaN  
 3rd Qu.: NA   3rd Qu.: NA  
 Max.   : NA   Max.   : NA  
 NA's   :3     NA's   :3    
Error: Stopping
In addition: Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
  There were missing values in resampled performance measures.

也许采样器没有并行工作?

我使用的是Caret的最新CRAN安装(6.0.77),但由于另一个错误(&#34;找不到optimismBoot&#34;)我不得不从github安装最新版本(devtools :: install_github)。

1 个答案:

答案 0 :(得分:2)

看起来您可能需要将包和变量导出到群集

registerDoParallel(cl)
# try these lines
clusterEvalQ(cl, { library(DMwR) })
clusterExport(cl, "smote_wrapper")

在并行模式下,插入符将在每个新工作者的环境中查找包/变量,但如果您不导出它们,它们将无法使用。希望这会有所帮助。