袋装树的平均预测概率

时间:2019-11-19 18:32:09

标签: r rstudio

使用来自rattle.data包的天气数据,我尝试过编写一个用于袋装树分类的脚本,其后的RainTomorrow是目标列

if(!require(rpart)) install.packages("rpart") 
if(!require(rpart.plot)) install.packages("rpart.plot") 
if(!require(caret)) install.packages("caret") 
if(!require(rattle.data)) install.packages("rattle.data") 
if(!require(tidyverse)) install.packages("tidyverse") 
if(!require(ipred)) install.packages("ipred") 
if(!require(Metrics)) install.packages("Metrics") 
library(rpart)
library(rpart.plot)
library(rattle.data)
library(tidyverse)
library(caret)
library(ipred)
library(Metrics)

set.seed(500)

data <- weather

# cleaning data
data <-
  data %>%
  mutate(month = months(Date)) %>% 
  select(-Date, -Location, -RISK_MM) %>% 
  mutate(RainTomorrow = as.factor(ifelse(RainTomorrow == "No", 0, 1))) %>% 
  na.omit()

# creating train and test data
index <- createDataPartition(data$RainTomorrow, p = .6, list = FALSE)
train_data <- data[ index, ]
test_data <- data[-index, ]

# creating models

bagged_tree <- bagging(formula = RainTomorrow ~ ., 
                        data = train_data,
                        coob = TRUE)

pred_bagg_class <- predict(object = bagged_tree ,    
                            newdata = test_data,  
                            type = "class") 

# predictions on the test set
pred_bagg <- predict(object = bagged_tree,
                newdata = test_data,
                type = "prob")

现在我需要的是平均所有预测的概率,然后选择具有最大概率的类,但是我总是得到0.5(如果我在pred_bagg上运行mean()或rowMeans()函数) ),这显然是不正确的,我是否错过了重要的事情?

1 个答案:

答案 0 :(得分:1)

因此,如果您需要查找所有预测值的均值,则可能需要这样做:

df <- as.data.frame(as.numeric(pred_bagg_class) - 1)
df <- cbind(df, pred_bagg)
df$pred_mean <- rowMeans(df)

哪个会给你:

 df
    as.numeric(pred_bagg_class) - 1    0    1 pred_mean
1                                 0 0.76 0.24 0.3333333
2                                 0 0.72 0.28 0.3333333
3                                 0 1.00 0.00 0.3333333
4                                 0 1.00 0.00 0.3333333
5                                 0 0.96 0.04 0.3333333
6                                 0 0.96 0.04 0.3333333
7                                 1 0.28 0.72 0.6666667
8                                 0 0.76 0.24 0.3333333
9                                 0 0.56 0.44 0.3333333
10                                0 0.84 0.16 0.3333333
11                                1 0.24 0.76 0.6666667

但是,如果您在rowMeans上使用pred_bagg,则总是会得到0.5,因为pred_bagg对于每种目标变量都有各自的概率,每行目标加1,如果您取一个平均值,每次将为您提供0.5。