Misclassified samples in the final model of caret's train function

Time: 2017-10-31 19:33:36

Tags: r random-forest cross-validation r-caret

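The train function in the caret package returns a final model, and I want to find the row indices of the misclassified samples in the original data frame. I run cross-validation as follows: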

library(caret)
# 5-fold cross-validation, keeping the held-out predictions and class probabilities
train_control <- trainControl(method = "cv", number = 5, savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

The final model is then:

> output$finalModel
Call:
randomForest(x = x, y = y, mtry = param$mtry) 
           Type of random forest: classification
                 Number of trees: 500
No. of variables tried at each split: 4

OOB estimate of  error rate: 4.67%
Confusion matrix:
             setosa versicolor virginica class.error
setosa         50          0         0        0.00
versicolor      0         47         3        0.06
virginica       0          4        46        0.08

Is there a way to find out which samples were misclassified (the 3 and 4 samples in the confusion matrix above)?

2 answers:

Answer 0 (score: 1)

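Another simple way is to inspect the predictions stored in the final model:

output$finalModel$predicted

For a random forest these are the out-of-bag predictions, the same ones summarised in the confusion matrix above. You can then compare them with the Species column of the original iris data.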
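A minimal sketch of that comparison, assuming caret and randomForest are installed. The set.seed() call is an addition for reproducibility, and the names oob_pred and miss are illustrative. It lists the rows whose out-of-bag prediction differs from the observed species and rebuilds the confusion matrix from the same predictions, which should agree with the one printed by output$finalModel.

```r
library(caret)

# Added only so the folds and the forest are reproducible
set.seed(1)
train_control <- trainControl(method = "cv", number = 5,
                              savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

# Out-of-bag class for every training row, in the row order of iris
oob_pred <- output$finalModel$predicted

# Row indices where the OOB prediction disagrees with the true species
miss <- unname(which(oob_pred != iris$Species))

# Should reproduce the confusion matrix printed by output$finalModel
print(table(observed = iris$Species, predicted = oob_pred))

# Misclassified rows with their observed and OOB-predicted species
print(data.frame(row = miss,
                 observed = iris$Species[miss],
                 predicted = oob_pred[miss]))
```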

Answer 1 (score: 0)

Try this:

library(dplyr)
# keep only the held-out predictions that disagree with the observed class
output$pred %>% filter(pred != obs)
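If dplyr is not available, the same filter can be written in base R. A minimal sketch, assuming caret and randomForest are installed; set.seed() is added for reproducibility and the name wrong is illustrative.

```r
library(caret)

set.seed(1)  # added for reproducibility
train_control <- trainControl(method = "cv", number = 5,
                              savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

# Held-out predictions that disagree with the observed class (all mtry values)
wrong <- output$pred[output$pred$pred != output$pred$obs, ]
print(wrong[, c("pred", "obs", "rowIndex", "mtry", "Resample")])
```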

Output:

         pred        obs setosa versicolor virginica rowIndex mtry Resample
1   virginica versicolor      0      0.084     0.916       71    2    Fold1
2  versicolor  virginica      0      0.976     0.024      107    2    Fold1
3   virginica versicolor      0      0.074     0.926       71    3    Fold1
4  versicolor  virginica      0      0.990     0.010      107    3    Fold1
5  versicolor  virginica      0      0.504     0.496      130    3    Fold1
6   virginica versicolor      0      0.070     0.930       71    4    Fold1
7  versicolor  virginica      0      0.992     0.008      107    4    Fold1
8  versicolor  virginica      0      0.550     0.450      130    4    Fold1
9   virginica versicolor      0      0.244     0.756       78    2    Fold2
10  virginica versicolor      0      0.172     0.828       78    3    Fold2
11  virginica versicolor      0      0.196     0.804       78    4    Fold2
12 versicolor  virginica      0      0.922     0.078      120    2    Fold3
13 versicolor  virginica      0      0.616     0.384      135    2    Fold3
14 versicolor  virginica      0      0.928     0.072      120    3    Fold3
15 versicolor  virginica      0      0.612     0.388      135    3    Fold3
16 versicolor  virginica      0      0.930     0.070      120    4    Fold3
17 versicolor  virginica      0      0.566     0.434      135    4    Fold3
18  virginica versicolor      0      0.352     0.648       84    2    Fold5
19  virginica versicolor      0      0.316     0.684       84    3    Fold5
20  virginica versicolor      0      0.256     0.744       84    4    Fold5

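Note that mtry is the number of variables randomly sampled as candidates at each split, and Resample identifies the cross-validation fold. Because savePredictions = TRUE keeps the held-out predictions for every candidate mtry value (2, 3 and 4 here), the same row can appear up to once per mtry value.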
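If only the selected mtry value is of interest, trainControl() also accepts savePredictions = "final", which keeps the held-out predictions for the optimal tuning parameters only; with 5-fold CV each row of iris then appears exactly once. A sketch, assuming caret and randomForest are installed; set.seed() and the name wrong are illustrative additions.

```r
library(caret)

set.seed(1)  # added for reproducibility
train_control <- trainControl(method = "cv", number = 5,
                              savePredictions = "final", classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

# Only the predictions for the selected mtry are stored now
print(output$bestTune)

# Misclassified rows under the selected mtry, one entry per row
wrong <- subset(output$pred, pred != obs)
print(sort(wrong$rowIndex))
```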

Let's plot the misclassified samples:

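d <- output$pred %>% 
  filter(pred != obs) %>%    # misclassified held-out predictions
  distinct(rowIndex) %>%     # each row once, whatever the mtry or fold
  pull(rowIndex) %>% 
  sort()

print(d)
# 71  78  84 107 120 130 134 135 139   (exact indices vary between runs, since no seed is set)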
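The rows in d come from held-out cross-validation predictions pooled over all mtry values, whereas the 3 and 4 errors in the question's confusion matrix are out-of-bag errors of the final model, so the two sets of rows need not be identical. A sketch that compares them, assuming caret and randomForest are installed; set.seed() and the names cv_rows and oob_rows are illustrative additions.

```r
library(caret)

set.seed(1)  # added for reproducibility
train_control <- trainControl(method = "cv", number = 5,
                              savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

# Rows misclassified in at least one held-out CV prediction (any mtry)
p <- output$pred
cv_rows <- sort(unique(p$rowIndex[p$pred != p$obs]))

# Rows misclassified out-of-bag by the final model
oob_rows <- unname(which(output$finalModel$predicted != iris$Species))

print(list(both     = intersect(cv_rows, oob_rows),
           cv_only  = setdiff(cv_rows, oob_rows),
           oob_only = setdiff(oob_rows, cv_rows)))
```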

library(ggplot2)

# Sepal measurements: misclassified samples drawn on top in black
ggplot(iris, aes(Sepal.Length, Sepal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], colour = "black")

# Petal measurements: misclassified samples drawn on top in black
ggplot(iris, aes(Petal.Length, Petal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], colour = "black")

[Figure: Sepal.Length ~ Sepal.Width]

[Figure: Petal.Length ~ Petal.Width]

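As the plots show, the misclassified samples are all versicolor or virginica flowers lying near the boundary between those two species, which gives an intuitive explanation of the results.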