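The `train` function in the caret package returns a final model. I would like to find the row indices, in the original data frame, of the samples that were misclassified. I run the cross-validation like this:

```r
library(caret)
train_control <- trainControl(method = "cv", number = 5, savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")
```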
The final model is then:

```
> output$finalModel

Call:
 randomForest(x = x, y = y, mtry = param$mtry)
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 4

        OOB estimate of error rate: 4.67%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         50          0         0        0.00
versicolor      0         47         3        0.06
virginica       0          4        46        0.08
```
Is there a way to find out which samples were misclassified (the 3 and 4 samples in the confusion matrix above)?
Answer 0 (score: 1)
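Another simple approach is to look at the predictions stored in the final model:

```r
output$finalModel$predicted
```

For a classification forest, `predicted` holds the out-of-bag prediction for each training row, i.e. the predictions summarised in the OOB confusion matrix above. You can then compare them with the `Species` column of the original `iris` data.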
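A minimal base-R sketch of that comparison follows. The helper names `misclassified_rows` and `misclassified_table` are hypothetical, and the code assumes the predictions are in the same row order as the original data frame (which holds here because `train` was given `iris` unchanged).

```r
# Hypothetical helper: indices of rows whose predicted class differs from the
# observed class. Assumes both vectors are aligned with the original data rows.
misclassified_rows <- function(predicted, observed) {
  stopifnot(length(predicted) == length(observed))
  # Compare as character so factors with different level orders still compare.
  unname(which(as.character(predicted) != as.character(observed)))
}

# Hypothetical helper: the same rows, with observed and predicted labels side by side.
misclassified_table <- function(predicted, observed) {
  idx <- misclassified_rows(predicted, observed)
  data.frame(row = idx,
             observed = as.character(observed)[idx],
             predicted = as.character(predicted)[idx],
             stringsAsFactors = FALSE)
}
```

With the model from the question, `misclassified_table(output$finalModel$predicted, iris$Species)` should list one row per off-diagonal count in the OOB confusion matrix (3 + 4 = 7). Note that these OOB misclassifications need not coincide exactly with the rows misclassified during 5-fold cross-validation, which are stored in `output$pred`.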
Answer 1 (score: 0)
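Try this (`filter_()` is deprecated in current dplyr, so plain `filter()` is used):

```r
library(dplyr)
# Hold-out predictions saved by caret, keeping only the wrong ones
output$pred %>% filter(pred != obs)
```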
Output:

```
         pred        obs setosa versicolor virginica rowIndex mtry Resample
 1  virginica versicolor      0      0.084     0.916       71    2    Fold1
 2 versicolor  virginica      0      0.976     0.024      107    2    Fold1
 3  virginica versicolor      0      0.074     0.926       71    3    Fold1
 4 versicolor  virginica      0      0.990     0.010      107    3    Fold1
 5 versicolor  virginica      0      0.504     0.496      130    3    Fold1
 6  virginica versicolor      0      0.070     0.930       71    4    Fold1
 7 versicolor  virginica      0      0.992     0.008      107    4    Fold1
 8 versicolor  virginica      0      0.550     0.450      130    4    Fold1
 9  virginica versicolor      0      0.244     0.756       78    2    Fold2
10  virginica versicolor      0      0.172     0.828       78    3    Fold2
11  virginica versicolor      0      0.196     0.804       78    4    Fold2
12 versicolor  virginica      0      0.922     0.078      120    2    Fold3
13 versicolor  virginica      0      0.616     0.384      135    2    Fold3
14 versicolor  virginica      0      0.928     0.072      120    3    Fold3
15 versicolor  virginica      0      0.612     0.388      135    3    Fold3
16 versicolor  virginica      0      0.930     0.070      120    4    Fold3
17 versicolor  virginica      0      0.566     0.434      135    4    Fold3
18  virginica versicolor      0      0.352     0.648       84    2    Fold5
19  virginica versicolor      0      0.316     0.684       84    3    Fold5
20  virginica versicolor      0      0.256     0.744       84    4    Fold5
```
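Note that `mtry` is the number of variables randomly sampled as candidates at each split, and `Resample` identifies the cross-validation fold. Because `savePredictions = TRUE` keeps the hold-out predictions for every candidate `mtry` value (2, 3 and 4 here), the same sample can appear several times; row 71, for example, is listed once per `mtry` value.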
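If you only want the misclassifications made at the tuning value that `train` finally selected, you can restrict `output$pred` to `output$bestTune` (setting `savePredictions = "final"` in `trainControl` achieves the same at fitting time). Below is a base-R sketch; `misclassified_at_tune` is a hypothetical helper, and it assumes the data frame has the `pred`, `obs` and `rowIndex` columns shown above plus one column per tuning parameter.

```r
# Hypothetical helper: distinct rows misclassified at one tuning-parameter value.
# pred_df: caret's saved hold-out predictions (e.g. output$pred)
# param:   name of the tuning column, e.g. "mtry"
# value:   the tuning value to keep, e.g. output$bestTune$mtry
misclassified_at_tune <- function(pred_df, param, value) {
  keep <- pred_df[[param]] == value &
    as.character(pred_df$pred) != as.character(pred_df$obs)
  sort(unique(pred_df$rowIndex[keep]))
}
```

Usage with the objects from the question: `misclassified_at_tune(output$pred, "mtry", output$bestTune$mtry)`.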
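Let's plot the misclassified samples:

```r
library(dplyr)
library(ggplot2)

# Distinct iris row indices misclassified for at least one mtry value
d <- output$pred %>%
  filter(pred != obs) %>%
  distinct(rowIndex) %>%
  pull(rowIndex) %>%
  sort()
print(d)
# 71 78 84 107 120 130 135   (for the run shown above; results vary between runs)

# Sepal measurements, misclassified samples drawn in black
ggplot(iris, aes(Sepal.Length, Sepal.Width, colour = Species)) +
  geom_point() +
  geom_point(data = iris[d, ], aes(x = Sepal.Length, y = Sepal.Width),
             color = "black")

# Petal measurements, misclassified samples drawn in black
ggplot(iris, aes(Petal.Length, Petal.Width, colour = Species)) +
  geom_point() +
  geom_point(data = iris[d, ], aes(x = Petal.Length, y = Petal.Width),
             color = "black")
```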
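If ggplot2 is not available, roughly the same picture can be drawn with base graphics. This is only a sketch: `plot_misclassified` is a hypothetical helper, the colours are arbitrary, and it assumes the data frame has a `Species` factor column like `iris`.

```r
# Hypothetical helper: scatter plot of two columns coloured by species,
# with the rows in `highlight` redrawn as black points.
plot_misclassified <- function(data, x, y, highlight) {
  palette_cols <- c("red", "green3", "blue")
  cols <- palette_cols[as.integer(data$Species)]
  plot(data[[x]], data[[y]], col = cols, pch = 19, xlab = x, ylab = y)
  points(data[[x]][highlight], data[[y]][highlight], col = "black", pch = 19)
  legend("topleft", legend = levels(data$Species), col = palette_cols, pch = 19)
  # Return the highlighted rows invisibly so they can be inspected afterwards
  invisible(data[highlight, c(x, y, "Species")])
}
```

For example, `plot_misclassified(iris, "Petal.Length", "Petal.Width", d)` draws the petal plot with the rows in `d` highlighted.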
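The plots give an intuitive picture of the result: all of the misclassified samples (drawn in black) are versicolor or virginica flowers.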