在R中与k-NN交叉验证时如何构造混淆矩阵

时间:2019-04-11 15:04:01

标签: r r-caret nearest-neighbor rpart

我试图在另一个论坛上问这个问题,但是在没有收到任何答复之后,我在这里重新制定了这个问题,并且对我的问题更加具体。我有一个像这样的数据集:

> head(knnresults)
   ACTIVITY_X ACTIVITY_Y ACTIVITY_Z classification
1:         40         47         62        Feeding
2:         60         74         95       Standing
3:         62         63         88       Standing
4:         60         56         82       Standing
5:         66         61         90       Standing
6:         60         53         80       Standing

classification 列具有三个不同的类别：Feeding、Standing 和 Foraging。

我现在想选择一个最佳的 k 值，因此我用 80% 的数据作为训练集，对其余 20% 的数据进行分类。分类基于前三列中的值。准确度最高的 k 值将被选用于之后的分类分析。

这是我一直在使用的脚本:

library(ISLR)
library(caret)    # also attaches lattice and ggplot2 as dependencies
library(lattice)
library(ggplot2)

# Seed BEFORE partitioning so the train/test split is reproducible.
# (The original script only seeded after the split, so the partition
# itself changed on every run.)
set.seed(400)

# Split the data for cross validation: 80% training, 20% hold-out test.
indxTrain <- createDataPartition(y = knnresults$classification, p = 0.8, list = FALSE)
training <- knnresults[indxTrain, ]
testing <- knnresults[-indxTrain, ]

# Run k-NN, tuning k via repeated cross-validation
# (default 10 folds, repeated 3 times).
ctrl <- trainControl(method = "repeatedcv", repeats = 3)
knnFit <- train(
  classification ~ .,
  data = training,
  method = "knn",
  trControl = ctrl,
  preProcess = c("center", "scale"),  # k-NN is distance-based: standardise features
  tuneLength = 20                     # evaluate 20 candidate values of k
)
knnFit

# Plot the candidate k values against accuracy
# (estimated from the repeated cross-validation).
plot(knnFit)

首先,由于我是R语言的新手,所以我深表歉意,我不确定此脚本的合法性。如果发现错误,我将很高兴接受任何纠正建议。

第二，如何基于此代码获取分类的混淆矩阵？这对于计算与分类相关的性能指标很重要。

如果有帮助，下面是该数据集的 dput() 输出：

> dput(knnresults)
structure(list(ACTIVITY_X = c(40L, 60L, 62L, 60L, 66L, 60L, 57L, 
54L, 52L, 93L, 80L, 14L, 61L, 51L, 40L, 20L, 21L, 5L, 53L, 48L, 
73L, 73L, 21L, 29L, 63L, 59L, 57L, 51L, 53L, 67L, 72L, 74L, 70L, 
60L, 74L, 85L, 77L, 68L, 58L, 80L, 34L, 45L, 34L, 60L, 75L, 62L, 
66L, 51L, 53L, 48L, 62L, 62L, 57L, 5L, 1L, 12L, 23L, 5L, 4L, 
0L, 13L, 45L, 44L, 31L, 68L, 88L, 43L, 70L, 18L, 83L, 71L, 67L, 
75L, 74L, 49L, 90L, 44L, 64L, 57L, 22L, 29L, 52L, 37L, 32L, 120L, 
45L, 22L, 54L, 30L, 9L, 27L, 14L, 3L, 29L, 12L, 61L, 60L, 29L, 
15L, 7L, 6L, 0L, 2L, 0L, 4L, 1L, 7L, 0L, 0L, 0L, 0L, 0L, 1L, 
23L, 49L, 46L, 8L, 31L, 45L, 60L, 37L, 61L, 52L, 51L, 38L, 86L, 
60L, 41L, 43L, 40L, 42L, 42L, 48L, 64L, 71L, 59L, 0L, 27L, 12L, 
3L, 0L, 0L, 8L, 21L, 6L, 2L, 7L, 4L, 3L, 3L, 46L, 46L, 59L, 53L, 
37L, 44L, 39L, 49L, 37L, 47L, 17L, 36L, 32L, 33L, 26L, 12L, 8L, 
31L, 35L, 27L, 27L, 24L, 17L, 35L, 39L, 28L, 54L, 5L, 0L, 0L, 
0L, 0L, 17L, 22L, 25L, 12L, 0L, 5L, 41L, 51L, 66L, 39L, 32L, 
53L, 43L, 40L, 44L, 45L, 48L, 51L, 41L, 45L, 39L, 46L, 59L, 31L, 
5L, 24L, 18L, 5L, 15L, 13L, 0L, 26L, 0L), ACTIVITY_Y = c(47L, 
74L, 63L, 56L, 61L, 53L, 40L, 41L, 49L, 32L, 54L, 13L, 99L, 130L, 
38L, 14L, 6L, 5L, 94L, 96L, 38L, 43L, 29L, 47L, 66L, 47L, 38L, 
31L, 36L, 35L, 38L, 72L, 54L, 44L, 45L, 51L, 80L, 48L, 39L, 85L, 
42L, 39L, 37L, 75L, 36L, 45L, 32L, 35L, 41L, 26L, 99L, 163L, 
124L, 0L, 0L, 24L, 37L, 0L, 6L, 0L, 29L, 29L, 26L, 27L, 54L, 
147L, 82L, 98L, 12L, 83L, 97L, 104L, 128L, 81L, 42L, 102L, 60L, 
79L, 58L, 15L, 14L, 75L, 75L, 40L, 130L, 40L, 13L, 54L, 42L, 
7L, 10L, 3L, 0L, 15L, 8L, 75L, 55L, 26L, 18L, 1L, 13L, 0L, 0L, 
0L, 1L, 0L, 4L, 0L, 0L, 0L, 0L, 0L, 0L, 17L, 45L, 38L, 10L, 31L, 
52L, 36L, 65L, 97L, 45L, 59L, 49L, 92L, 51L, 34L, 21L, 20L, 29L, 
28L, 22L, 32L, 30L, 86L, 0L, 15L, 7L, 4L, 0L, 0L, 0L, 11L, 3L, 
0L, 1L, 3L, 1L, 0L, 72L, 62L, 98L, 55L, 26L, 39L, 28L, 81L, 20L, 
52L, 12L, 48L, 24L, 40L, 30L, 5L, 6L, 40L, 37L, 33L, 26L, 17L, 
14L, 39L, 27L, 28L, 67L, 0L, 0L, 0L, 0L, 0L, 10L, 12L, 14L, 7L, 
0L, 2L, 39L, 67L, 74L, 28L, 23L, 57L, 34L, 36L, 36L, 37L, 46L, 
43L, 73L, 65L, 31L, 64L, 128L, 17L, 3L, 12L, 17L, 0L, 9L, 7L, 
0L, 17L, 0L), ACTIVITY_Z = c(62L, 95L, 88L, 82L, 90L, 80L, 70L, 
68L, 71L, 98L, 97L, 19L, 116L, 140L, 55L, 24L, 22L, 7L, 108L, 
107L, 82L, 85L, 36L, 55L, 91L, 75L, 69L, 60L, 64L, 76L, 81L, 
103L, 88L, 74L, 87L, 99L, 111L, 83L, 70L, 117L, 54L, 60L, 50L, 
96L, 83L, 77L, 73L, 62L, 67L, 55L, 117L, 174L, 136L, 5L, 1L, 
27L, 44L, 5L, 7L, 0L, 32L, 54L, 51L, 41L, 87L, 171L, 93L, 120L, 
22L, 117L, 120L, 124L, 148L, 110L, 65L, 136L, 74L, 102L, 81L, 
27L, 32L, 91L, 84L, 51L, 177L, 60L, 26L, 76L, 52L, 11L, 29L, 
14L, 3L, 33L, 14L, 97L, 81L, 39L, 23L, 7L, 14L, 0L, 2L, 0L, 4L, 
1L, 8L, 0L, 0L, 0L, 0L, 0L, 1L, 29L, 67L, 60L, 13L, 44L, 69L, 
70L, 75L, 115L, 69L, 78L, 62L, 126L, 79L, 53L, 48L, 45L, 51L, 
50L, 53L, 72L, 77L, 104L, 0L, 31L, 14L, 5L, 0L, 0L, 8L, 24L, 
7L, 2L, 7L, 5L, 3L, 3L, 85L, 77L, 114L, 76L, 45L, 59L, 48L, 95L, 
42L, 70L, 21L, 60L, 40L, 52L, 40L, 13L, 10L, 51L, 51L, 43L, 37L, 
29L, 22L, 52L, 47L, 40L, 86L, 5L, 0L, 0L, 0L, 0L, 20L, 25L, 29L, 
14L, 0L, 5L, 57L, 84L, 99L, 48L, 39L, 78L, 55L, 54L, 57L, 58L, 
66L, 67L, 84L, 79L, 50L, 79L, 141L, 35L, 6L, 27L, 25L, 5L, 17L, 
15L, 0L, 31L, 0L), classification = c("Feeding", "Standing", 
"Standing", "Standing", "Standing", "Standing", "Feeding", "Feeding", 
"Feeding", "Standing", "Standing", "Foraging", "Standing", "Standing", 
"Feeding", "Foraging", "Foraging", "Foraging", "Standing", "Standing", 
"Standing", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Standing", 
"Standing", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Feeding", "Feeding", "Standing", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Standing", "Foraging", "Foraging", "Foraging", 
"Feeding", "Foraging", "Foraging", "Foraging", "Foraging", "Feeding", 
"Feeding", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Foraging", "Standing", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Standing", "Standing", "Foraging", 
"Foraging", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Feeding", "Feeding", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Standing", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Feeding", "Feeding", 
"Foraging", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Feeding", "Standing", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Standing", "Feeding", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Standing", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Feeding", "Standing", "Standing", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging")), row.names = c(NA, -215L
), class = c("data.table", "data.frame"), .internal.selfref = <pointer: 0x0000000002531ef0>)

任何输入都值得赞赏!

1 个答案:

答案 0 :(得分:1)

以下是可重现的示例:

library(caret)

# Hold out 20% of iris as a validation set; train on the remaining 80%.
partition_idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
valid_set <- iris[-partition_idx, ]
train_set <- iris[partition_idx, ]

# Tune k-NN with plain 5-fold cross-validation, selecting on accuracy.
ctrl <- trainControl(method = "cv", number = 5)
set.seed(233)
mk <- train(
  Species ~ .,
  data = train_set,
  method = "knn",
  trControl = ctrl,
  metric = "Accuracy"
)

这样即可获取混淆矩阵。理想情况下，最好将模型对测试集或验证集调用 predict() 得到的预测值与真实类别进行比较。

补充：要单独取出混淆矩阵的表格，只需执行以下操作：

confusionMatrix(mk)["table"]
$table
            Reference
Prediction       setosa versicolor  virginica
  setosa     33.3333333  0.0000000  0.0000000
  versicolor  0.0000000 32.5000000  2.5000000
  virginica   0.0000000  0.8333333 30.8333333

原始

 confusionMatrix(mk)

结果:

Cross-Validated (5 fold) Confusion Matrix 

(entries are percentual average cell counts across resamples)

            Reference
Prediction   setosa versicolor virginica
  setosa       33.3        0.0       0.0
  versicolor    0.0       31.7       1.7
  virginica     0.0        1.7      31.7

 Accuracy (average) : 0.9667