I want to make my code reproducible and use the seeds argument together with createMultiFolds inside a loop.
I have set up the following code:
cv_model <- function(dat, targets){
  library(randomForest)
  library(caret)
  library(MLmetrics)
  library(Metrics)
  results <<- list(weight = NA, vari = NA)
  # set up error measures
  sumfct <- function(data, lev = NULL, model = NULL){
    mape <- MLmetrics::MAPE(y_pred = data$pred, y_true = data$obs)
    # na.rm (not na.omit) is the argument mean() uses to drop missing values
    RMSE <- sqrt(mean((data$pred - data$obs)^2, na.rm = TRUE))
    c(MAPE = mape, RMSE = RMSE)
  }
  for (i in 1:length(targets)) {
    set.seed(43)
    folds <- caret::createMultiFolds(y = dat$weight,
                                     k = 3,
                                     times = 3)
    set.seed(43)
    myseeds <- vector(mode = "list", length = 3*3+1)
    for (i in 1:9) {
      myseeds[[i]] <- sample.int(n=1000, 1)
    }
    # for the final model
    myseeds[[10]] <- sample.int(n=1000, 1)
    # specify trainControl
    control <- caret::trainControl(method = "repeatedcv", number = 3, repeats = 3, search = "grid",
                                   savePredictions = TRUE,
                                   summaryFunction = sumfct, index = folds, seeds = myseeds)
    # fixed mtry
    params <- data.frame(mtry = 2)
    # choose predictor columns by excluding target columns
    preds <- dat[, -c(which(names(dat) == "Time"),
                      which(names(dat) == "Chick"),
                      which(names(dat) == "Diet"))]
    # set target variables
    response <- dat[, which(names(dat) == targets[i])]
    set.seed(42)
    model <- caret::train(x = preds,
                          y = response,
                          data = dat,
                          method = "rf",
                          ntree = 25,
                          metric = "RMSE",
                          tuneGrid = params,
                          trControl = control)
    results[[i]] <<- model
  }
}
targets <- c("weight", "vari")
dat <- as.data.frame(ChickWeight)
# generate random numbers
set.seed(1)
dat$vari <- c(runif(nrow(dat)))
## use 2 of the cores
library(doParallel)
cl <- makePSOCKcluster(2)
registerDoParallel(cl)
# use function
cv_model(dat = dat, targets = targets)
# end parallel computing
stopCluster(cl)
# unregister doParallel by registering DoSeq (do sequential)
registerDoSEQ()
After running the code, I get the error message Error: Please make sure 'y' is a factor or numeric value.
If I remove the following lines
set.seed(43)
myseeds <- vector(mode = "list", length = 3*3+1)
for (i in 1:9) {
  myseeds[[i]] <- sample.int(n=1000, 1)
}
# for the final model
myseeds[[10]] <- sample.int(n=1000, 1)
and also remove
, seeds = myseeds
from trainControl, then the code runs without an error message.
How can I fix this error while still providing seeds and createMultiFolds in my code?
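A possible reading of the code above, offered as a hedged sketch rather than a confirmed diagnosis: the inner seed loop reuses the outer loop index i. Once for (i in 1:9) has finished, i is 9, so targets[i] is NA, no column name matches it, and the object passed as y is a zero-column data frame instead of a numeric vector. That would also explain why removing the seed lines makes the error disappear. The base-R sketch below reproduces the index clash and shows one way to build the seed list without touching i. The helper make_seeds and the index j are hypothetical names introduced only for illustration; they are not part of caret.

```r
# Base-R sketch only; make_seeds and j are hypothetical, not caret functions.
targets <- c("weight", "vari")
dat <- as.data.frame(ChickWeight)

# 1) Reproduce the suspected problem: the inner loop overwrites i.
i <- 1
for (i in 1:9) {
  # body left empty on purpose; only the loop index matters here
}
bad_target <- targets[i]            # i is now 9, so this is NA
bad_response <- dat[, which(names(dat) == bad_target)]
# bad_response is a data frame with zero columns, not a numeric vector

# 2) Build the seeds without reusing i.
# n_resamples: number of resamples (folds * repeats)
# n_tune:      number of rows in the tuning grid
make_seeds <- function(n_resamples, n_tune, seed = 43) {
  set.seed(seed)
  # one integer vector of length n_tune per resample
  seeds <- lapply(seq_len(n_resamples),
                  function(j) sample.int(n = 1000, size = n_tune))
  # one extra single integer for the final model fit
  seeds[[n_resamples + 1]] <- sample.int(n = 1000, size = 1)
  seeds
}

for (i in seq_along(targets)) {
  myseeds <- make_seeds(n_resamples = 3 * 3, n_tune = 1)
  # i still refers to the current target here
  good_response <- dat[, which(names(dat) == "weight")]
}
```

Inside cv_model, the same effect should be achievable by simply renaming the inner index (for example for (j in 1:9)) or by calling a helper like make_seeds; the seeds = myseeds and index = folds arguments to trainControl can then stay as they are. The list has 3 * 3 + 1 elements because, as far as the trainControl documentation describes it, seeds expects one integer vector per resample (with one entry per tuning-grid row) plus a single integer for the final model.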