R caret包中的列车功能

时间:2017-10-04 18:59:05

标签: r r-caret

假设我有一个数据集,并且我想使用逻辑回归进行4倍交叉验证。所以将有4种不同的型号。在R中,我做了以下事情:

ctrl <- trainControl(method = "repeatedcv", number = 4, savePredictions = TRUE)
mod_fit <- train(outcome ~., data=data1, method = "glm", family="binomial", trControl = ctrl)

我认为 mod_fit 应该包含4组不同的系数?当我输入 modfit$finalModel$ 时,我只得到相同的系数集。

2 个答案:

答案 0 :(得分:4)

我已根据您的代码段创建了一个可重现的示例。关于代码的第一件事就是它指定repeatedcv作为方法,但它没有给出任何repeats,因此number=4参数只是告诉它重新采样4次(这不是你问题的答案,但要理解很重要)。

mod_fit$finalModel只给出一组系数,因为它是通过对4个折叠中的每一个进行非重复k倍CV结果而得出的最终模型。< / p>

您可以在resample对象中看到折叠级别的效果:

library(caret)
library(mlbench)

data(iris)

iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

ctrl    <- trainControl(method = "repeatedcv", 
                        number = 4, 
                        savePredictions = TRUE,
                        verboseIter = T,
                        returnResamp = "all")

mod_fit <- train(binary ~., 
                 data=iris, 
                 method = "glm", 
                 family="binomial", 
                 trControl = ctrl)


# Fold-level Performance
mod_fit$resample
          RMSE  Rsquared parameter   Resample
1 2.630866e-03 0.9999658      none Fold1.Rep1
2 3.863821e-08 1.0000000      none Fold2.Rep1
3 8.162472e-12 1.0000000      none Fold3.Rep1
4 2.559189e-13 1.0000000      none Fold4.Rep1

对于您之前的观点,该软件包不会保存并显示每个折叠系数的信息。此外,上面的性能信息确实保存了index(样本内行的列表),indexOut(保存行的方式),以及每个折叠的随机种子,因此如果你这么倾向重建中间模型很容易。

mod_fit$control$seeds
[[1]]
[1] 169815

[[2]]
[1] 445763

[[3]]
[1] 871613

[[4]]
[1] 706905

[[5]]
[1] 89408
mod_fit$control$index
$Fold1
  [1]   1   2   3   4   5   6   7   8   9  10  11  12  15  18  19  21  22  24  28  30  31  32  33  34  35  40  41  42  43  44  45  46  47
     

48 49 50 51 52 53 54 59 60 61 63        [45] 64 65 66 68 69 70 71 72 73 75 76 77 79 80 81 82 84 85 86 87 89 90 91 92 93 94 95 96 98 99 100 103 104   106 107 108 110 111 113 114 116 118 119 120        [89] 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 140 141 142 143 145 147 149 150

$Fold2
  [1]   1   6   7   8  12  13  14  15  16  17  18  19  20  21  22  23  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  42
     

44 46 48 50 51 53 54 55 56 57 58        [45] 59 61 62 64 66 67 69 70 71 72 73 74 75 76 78 79 80 81 82 83 84 85 87 88 89 90 91 92 95 96 97 98 99   101 102 104 105 106 108 109 111 112 113 115        [89] 116 117 119 120 121 122 123 127 130 131 132 134 135 137 138 139 140 141 142 143 144 145 146 147 148

$Fold3
  [1]   2   3   4   5   6   7   8   9  10  11  13  14  16  17  20  23  24  25  26  27  28  29  30  33  35  36  37  38  39  40  41  43  45
     

46 47 49 50 51 52 54 55 56 57 58        [45] 60 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 82 83 84 85 86 88 89 93 94 97 98 99 100 101 102   103 105 106 107 108 109 110 111 112 114 115        [89] 117 118 119 121 124 125 126 128 129 131 132 133 134 135 136 137 138 139 144 145 146 147 148 149 150

$Fold4
  [1]   1   2   3   4   5   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  29  31  32  34  36  37  38  39  41
     

42 43 44 45 47 48 49 52 53 55 56        [45] 57 58 59 60 61 62 63 65 67 68 74 77 78 79 80 81 83 86 87 88 90 91 92 93 94 95 96 97 100 101 102 103 104   105 107 109 110 112 113 114 115 116 117 118        [89] 120 122 123 124 125 126 127 128 129 130 133 136 137 138 139 140 141 142 143 144 146 148 149 150

mod_fit$control$indexOut
$Resample1
 [1]  13  14  16  17  20  23  25  26  27  29  36  37  38  39  55  56  57  58  62  67  74  78  83  88  97 101 102 105 109 112 115 117 137
138 139 144 146 148

$Resample2
 [1]   2   3   4   5   9  10  11  24  41  43  45  47  49  52  60  63  65  68  77  86  93  94 100 103 107 110 114 118 124 125 126 128 129
133 136 149 150

$Resample3
 [1]   1  12  15  18  19  21  22  31  32  34  42  44  48  53  59  61  79  80  81  87  90  91  92  95  96 104 113 116 120 122 123 127 130
140 141 142 143

$Resample4
 [1]   6   7   8  28  30  33  35  40  46  50  51  54  64  66  69  70  71  72  73  75  76  82  84  85  89  98  99 106 108 111 119 121 131
132 134 135 145 147

答案 1 :(得分:1)

@Damien您的mod_fit将不包含4组独立的系数。您要的是cross validation折的4。这并不意味着您将有4种不同的模型。根据文档heretrain函数的工作方式如下:

enter image description here

在重采样循环的末尾-在您的情况下,进行4次迭代4次,对于给定的一组模型参数,您将具有一组平均预测准确性度量(例如rmse,R平方)。

由于您没有在tuneGrid函数中使用tuneLengthtrain参数,因此默认情况下,train函数将在每个可调整参数的三个值上进行调整。

这意味着您最多将拥有三个模型(而不是您期望的四个模型),因此将拥有三组平均模型性能指标。

最优模型是在回归的情况下具有最低均方根值的模型。该模型系数在mod_fit$finalModel中可用。