Create a confusion matrix from 3 columns of predicted probabilities based on my data

Time: 2019-08-17 15:50:50

Tags: r

I am trying to create a confusion matrix.

My data looks like this:

     class    Growth  Negative   Neutral
1   Growth 0.3082588 0.2993632 0.3923780
2  Neutral 0.4696949 0.2918042 0.2385009
3 Negative 0.3608549 0.2679748 0.3711703
4  Neutral 0.3636836 0.2431433 0.3931730
5   Growth 0.4325862 0.2011520 0.3662619
6 Negative 0.2939859 0.2397171 0.4662970

where class is the "true" observed outcome, and Growth, Negative, and Neutral are the model's predicted probabilities of belonging to each of those classes. For example, in the first row Neutral has the highest probability (0.3923780), so the model would incorrectly predict Neutral when the actual class is Growth.

I would normally use the confusionMatrix() function from caret, but my data is in a slightly different shape. Should I create a new column called pred_class that holds the name of the column with the highest value? Something like this:

     class    Growth  Negative   Neutral   pred_class
1   Growth 0.3082588 0.2993632 0.3923780    Neutral
2  Neutral 0.4696949 0.2918042 0.2385009    Growth
3 Negative 0.3608549 0.2679748 0.3711703    Neutral
4  Neutral 0.3636836 0.2431433 0.3931730    Neutral
5   Growth 0.4325862 0.2011520 0.3662619    Growth
6 Negative 0.2939859 0.2397171 0.4662970    Neutral

Then I could do something like confusionMatrix(df$pred_class, df$class). How can I write a function that puts the column name into a new column based on the highest probability?

Data:

df <- structure(list(class = c("Growth", "Neutral", "Negative", "Neutral", 
"Growth", "Negative", "Neutral", "Neutral", "Neutral", "Neutral", 
"Neutral", "Negative", "Neutral", "Growth", "Growth", "Growth", 
"Negative", "Negative", "Growth", "Negative"), Growth = c(0.308258818045192, 
0.469694864370061, 0.360854910973552, 0.363683641698332, 0.43258619401693, 
0.2939858517149, 0.397951949316298, 0.235376278828237, 0.3685791718903, 
0.330295647415191, 0.212072592205125, 0.220703558050626, 0.389445269278106, 
0.286933037813081, 0.315659629884986, 0.30185119811882, 0.273429057319956, 
0.277357131556229, 0.339004410008943, 0.407114176119814), Negative = c(0.299363167088292, 
0.291804233603859, 0.267974798034839, 0.243143322044808, 0.201151951415105, 
0.239717129555608, 0.351629585705591, 0.258325790152011, 0.281660024058527, 
0.189920159505041, 0.265058882513953, 0.433664278547707, 0.114765460651494, 
0.402354633060689, 0.370370354887748, 0.3239536031819, 0.3279406609037, 
0.327198131828346, 0.298583999674218, 0.337902573718712), Neutral = c(0.392378014866516, 
0.23850090202608, 0.371170290991609, 0.39317303625686, 0.366261854567965, 
0.466297018729492, 0.250418464978111, 0.506297931019752, 0.349760804051173, 
0.479784193079769, 0.522868525280922, 0.345632163401667, 0.4957892700704, 
0.31071232912623, 0.313970015227266, 0.374195198699279, 0.398630281776344, 
0.395444736615424, 0.362411590316838, 0.254983250161474)), row.names = c(NA, 
20L), class = "data.frame")

1 answer:

Answer 0 (score: 1)

#Vector of observed values
observed = df$class

#Remove first column from df so that we only have numeric values
temp = df[,-1]

#Obtain the predicted values based on column number
#of the maximum values in each row of temp
predicted = names(temp)[max.col(temp, ties.method = "first")]

#Create a union of the observed and predicted values
#so that all values are accounted for when we do 'table'
lvls = unique(c(observed, predicted))

#Convert observed and predicted values to factor
#with all levels that we created above
observed = factor(x = observed, levels = lvls)
predicted = factor(predicted, levels = lvls)

#Tabulate values
m = table(predicted, observed)

#Run confusionMatrix
library(caret)
confusionMatrix(m)
# Confusion Matrix and Statistics

          # observed
# predicted  Growth Neutral Negative
  # Growth        1       3        1
  # Neutral       3       5        4
  # Negative      2       0        1

# Overall Statistics

               # Accuracy : 0.35            
                 # 95% CI : (0.1539, 0.5922)
    # No Information Rate : 0.4             
    # P-Value [Acc > NIR] : 0.7500          

                  # Kappa : -0.0156         

 # Mcnemar's Test P-Value : 0.2276          

# Statistics by Class:

                     # Class: Growth Class: Neutral Class: Negative
# Sensitivity                 0.1667         0.6250          0.1667
# Specificity                 0.7143         0.4167          0.8571
# Pos Pred Value              0.2000         0.4167          0.3333
# Neg Pred Value              0.6667         0.6250          0.7059
# Prevalence                  0.3000         0.4000          0.3000
# Detection Rate              0.0500         0.2500          0.0500
# Detection Prevalence        0.2500         0.6000          0.1500
# Balanced Accuracy           0.4405         0.5208          0.5119
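If you prefer the pred_class-column approach sketched in the question itself, the same max.col() idea works directly on the data frame. A minimal sketch, using a small stand-in data frame with the same structure as df above (values abbreviated for illustration):

```r
# Stand-in data frame: true class plus one probability column per class
df <- data.frame(
  class    = c("Growth", "Neutral", "Negative"),
  Growth   = c(0.31, 0.47, 0.36),
  Negative = c(0.30, 0.29, 0.27),
  Neutral  = c(0.39, 0.24, 0.37),
  stringsAsFactors = FALSE
)

# Probability columns only (drop the first, non-numeric column)
probs <- df[, -1]

# Name of the highest-probability column in each row
df$pred_class <- names(probs)[max.col(probs, ties.method = "first")]
df$pred_class
# [1] "Neutral" "Growth"  "Neutral"
```

As in the answer above, convert both df$pred_class and df$class to factors with a shared level set (e.g. lvls <- unique(c(df$class, df$pred_class))) before passing them to caret's confusionMatrix(), otherwise a class that happens never to be predicted would be dropped from the table.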