如何检索弹性网系数?

时间:2016-10-17 13:51:10

标签: r r-caret

我正在使用插入包来训练我的数据集modDat上的弹性网模型。我采用网格搜索方法与重复交叉验证配对,以选择弹性网络函数所需的λ和分数参数的最佳值。我的代码如下所示。

library(caret)
library(elasticnet)

grid <- expand.grid(
      lambda   = seq(0.5, 0.7, by=0.1),
      fraction = seq(0, 1, by=0.1)
    )

ctrl <- trainControl(
      method     = 'repeatedcv',
      number     = 5,  #folds
      repeats    = 10, #repeats
      classProbs = FALSE
    )

set.seed(1)
enetTune <- train(
          y ~ .,
          data = modDat,
          method = 'enet',
          metric = 'RMSE',
          tuneGrid = grid,
          verbose = FALSE,
          trControl = ctrl
        )

我可以使用y_hat <- predict(enetTune, modDat)获得预测,但我无法查看预测所依据的系数。

我试过coef(enetTune$finalModel),但唯一返回的是NULL。我怀疑我必须向coef()函数提供更多信息,但不知道如何执行此操作。

此外,我想制作一组与最佳λ和分数参数相关的50组系数(10次重复5次)的箱形图。

3 个答案:

答案 0 :(得分:2)

要查看系数,请使用predict

predict(enetTune$finalModel, type = "coefficients")

有关如何获取特定系数的详细信息,请参阅?predict.enet

答案 1 :(得分:2)

根据@Weihuang Wong的回答,您可以使用以下代码从最终模型中获取系数:

predict.enet(enetTune$finalModel, s=enetTune$bestTune[1, "fraction"], type="coef", mode="fraction")$coefficients

答案 2 :(得分:0)

对我来说,最好的方法是stats::predict,就像@Weihuang Wong的答案一样。但是,正如OP在评论中指出的那样,它提供了所测试的每个lambda值的系数列表。

这里要了解的重要一点是,当您使用predict时,您的意图恰恰是预测参数的值,而不是真正地检索它们。然后,您应该知道浏览可用的选项。

在这种情况下,可以对惩罚参数lambda使用与参数s相同的函数。记住您仍在预测中,但是这次您将获得所需的系数。

stats::predict(enetTune$finalModel, type = "coefficients", s = enetTune$bestTune$lambda)