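Why doesn't random forest regression in R like my input data?

Asked: 2016-06-06 05:57:36

Tags: r vector dataframe random-forest

I'm trying a very simple random forest, shown below. The code is fully self-contained and runnable.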

library(randomForest)
n = 1000

factor=10
x1 = seq(n) + rnorm(n, 0, 150)
y = x1*factor + rnorm(n, 0, 550)

x_data = data.frame(x1)
y_data = data.frame(y)

k=2
for (nfold in seq(k)){
  fold_ids <- cut(seq(1, nrow(x_data)), breaks=k, labels=FALSE)
  id_indices <- which(fold_ids==nfold)
  fold_x <- x_data[id_indices,]
  fold_y <- y_data[id_indices,]
  fold_x_df = data.frame(x=fold_x)
  fold_y_df = data.frame(y=fold_y)
  print(paste("number of rows in fold_x_df is ", nrow(fold_x_df), sep=" "))
  print(paste("number of rows in fold_y_df is ", nrow(fold_y_df), sep=" "))
  rf = randomForest(fold_x_df, fold_y_df, ntree=1000)
  print(paste("mse for fold number  ", " is ", sum(rf$mse)))
}

rf = randomForest(x_data, y_data, ntree=1000)

It gives me this error:

...The response has five or fewer unique values.  Are you sure you want to do regression?

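I don't understand why it gives me this error.

I have checked these sources:

Use of randomforest() for classification in R?
RandomForest error code
https://www.kaggle.com/c/15-071x-the-analytics-edge-competition-spring-2015/forums/t/13383/warning-message-in-random-forest

None of these solved my problem. As the print statements show, there are more than 5 unique labels. Besides, I'm doing regression here, not classification, so I'm not sure why the word "label" is used in the error.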

1 answer:

Answer 0: (score: 1)

The problem is that the response is supplied as a data frame. The response must be one-dimensional, so it should be a vector. Here is how I would simplify your code to use randomForest's formula method with the data argument, which avoids the issue entirely:

## simulation: unchanged (but seed set for reproducibility)
library(randomForest)
n = 1000
factor = 10
set.seed(47)
x1 = seq(n) + rnorm(n, 0, 150)
y = x1*factor + rnorm(n, 0, 550)

## use a single data frame
all_data = data.frame(y, x1)

## define the folds outside the loop
k = 2
fold_ids <- cut(seq(1, nrow(all_data)), breaks = k, labels = FALSE)

for (nfold in seq(k)) {
  id_indices <- which(fold_ids == nfold)

  ## sprintf can be nicer than paste for "filling in blanks"
  print(sprintf("number of rows in fold %s is %s", nfold, length(id_indices)))

  ## just pass the subset of the data directly to randomForest
  ## no need for extracting, subsetting, putting back in data frames...
  rf <- randomForest(y ~ ., data = all_data[id_indices, ], ntree = 1000)

  ## sprintf also allows for formatting
  ## %g switches to scientific notation when the exponent is < -4
  ## or >= the precision (6 significant digits by default)
  print(sprintf("mse for fold %s is %g", nfold, sum(rf$mse)))
}
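To see why a data frame response could trigger the "five or fewer unique values" message even though y has 1000 distinct values, here is a small base R check. It is only a sketch: it assumes randomForest counts unique response values with something like length(unique(y)), which I believe is the case but which may differ between package versions. The names y_df, y_vec, n_unique_df and n_unique_vec are made up for illustration.

```r
## Base R only; randomForest is not needed for this check.
set.seed(47)
n <- 1000
x1 <- seq(n) + rnorm(n, 0, 150)
y <- x1 * 10 + rnorm(n, 0, 550)

y_df <- data.frame(y)   # response wrapped in a data frame, as in the question
y_vec <- y_df$y         # the same response as a plain numeric vector

## unique() on a data frame returns a data frame of unique rows,
## and length() of a data frame is its number of COLUMNS, not rows.
n_unique_df <- length(unique(y_df))

## On a vector, length(unique()) counts distinct values as expected.
n_unique_vec <- length(unique(y_vec))

print(n_unique_df)           # number of columns in y_df
print(n_unique_vec)          # number of distinct response values
print(is.null(dim(y_vec)))   # a plain vector has no dim attribute
```

Under that assumption, a one-column data frame always looks like it has a single unique value, so the check fires no matter what the data are. If you want to keep the x/y interface from the question rather than the formula interface, passing the column itself (for example fold_y_df$y or y_data$y) instead of the data frame should avoid the message.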
