Caret函数predict.gb()中可能存在错误?

时间:2015-02-05 05:16:03

标签: r predict r-caret gbm

在我看来,我在R的Caret包中发现了method = gbm的predict()函数的性能错误。我很想知道其他人是否同意,或者是否有人有对此函数行为的解释。

1。生成数据

library(caret)

x1 <- rnorm(100)

x2 <- rnorm(100, 2)

y <- x1 + x2 + rnorm(100)

df <- data.frame(x1=x1, x2=x2,  y=y)

2。预测使用方法=&#34; lm&#34;

以下代码按预期工作:using method =“lm”两个预测值匹配。在第一种情况下,p1,“y”包含在newdata中,在第二种情况下,p2,它不是。

tempd <- df[1:99, c("y", "x1", "x2") ]

newdata <- df[100, c("y", "x1", "x2")]

lm.fit <- train(y~x1 + x2, data=tempd, method="lm")

p1 <- predict(lm.fit$finalModel, newdata=newdata)

newdata <- df[100, c("x1", "x2")]

p2 <- predict(lm.fit$finalModel, newdata=newdata)

p1应该等于p2,并且确实:

p1==p2

第3。预测使用方法=&#34; gbm&#34;

此代码无法正常工作:使用method =“gbm”,设置相同,两个预测值不匹配。

tempd <- df[1:99, c("y","x1","x2")]

newdata <- df[100, c("y","x1","x2")]

gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                       
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1","x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

在这种情况下,p1不等于p2:

p1==p2

4。预测使用方法=&#34; gbm&#34;使用不同的设置

但是,奇怪的是,只有一个小的改变 - 没有明确命名子集操作中的变量 - 它确实有效:

tempd <- df[1:99, ]

newdata <- df[100, ]

gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                                         
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1","x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

p1==p2

提前感谢我们的想法。

杰夫

1 个答案:

答案 0 :(得分:2)

正如@Pascal指出的那样,你正在跳过一个重要的步骤。您应该直接在predict()对象上调用predict,而不是在finalModel值上调用gmb.fit。注意

class(gbm.fit)
# [1] "train"         "train.formula"
class(gbm.fit$finalModel)
# [1] "gbm"

由于这些对象具有不同的类,因此它们会触发不同的基础预测功能。重要的是,predict.trainnewdata重新整形为gbm预测变量的正确格式。如果没有这些数据重新整形,您将得到不正确的结果(预测器希望列按特定顺序排列)

观察

newdata1 <- df[100, c("y","x1","x2")]
newdata2 <- df[100, c("x1","x2")]
newdata3 <- df[100, ]

predict(gbm.fit, newdata1)
# [1] 1.427069
predict(gbm.fit, newdata2)
# [1] 1.427069
predict(gbm.fit, newdata3)
# [1] 1.427069

predict(gbm.fit$finalModel, newdata=newdata1,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 2.166468
predict(gbm.fit$finalModel, newdata=newdata2,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata3,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069

因此,如果您要使用train()函数来拟合模型,请务必使用正确的predict.train函数从模型中正确进行预测。