GBM R函数:为每个类分别获取变量重要性

时间:2015-04-14 20:49:43

标签: r machine-learning classification data-mining gbm

我在R(gbm包)中使用gbm函数来拟合用于多类分类的随机梯度增强模型。我只是试图获得每个类的每个预测器单独的重要性,就像Hastie book (the Elements of Statistical Learning)(第382页)中的这张图片一样。

enter image description here

但是,函数summary.gbm仅返回预测变量的整体重要性(它们对所有类的平均重要性)。

有谁知道如何获得相对重要性值?

2 个答案:

答案 0 :(得分:11)

我认为简短的回答是,在第379页,Hastie提到他使用MART,这似乎只适用于Splus。

我同意gbm包似乎不允许看到单独的相对影响。如果你对mutliclass问题感兴趣,你可能会通过为每个类构建一个vs-all-gbm然后从每个模型中获取重要性度量来获得非常相似的东西。

所以说你的课程是a,b,c,& d。您对其他模型进行建模,并从该模型中获得重要性。然后你模拟b与其余的模型,并从该模型中获得重要性。等

答案 1 :(得分:2)

我深入研究了gbm软件包如何计算重要性,它基于ErrorReduction,该错误包含在结果的tree元素中,可以使用pretty.gbm.trees()进行访问。相对影响是通过对每个变量取所有树上的ErrorReduction的总和而获得的。对于多类问题,模型中实际上有n.trees*num.classes个树。因此,如果有3个类,则可以计算每三棵树上每个变量的ErrorReduction的总和,以得出一个类的重要性。我编写了以下函数来实现此目的,然后绘制结果:

按类别获取变量重要性

RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE){
  library(dplyr)
  library(purrr)
  library(gbm)
  Ext_ErrRed<- function(ptree){
    ErrRed <- ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>% 
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~pretty.gbm.tree(object, .)) %>% 
    map(Ext_ErrRed)

  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees/n.classes)) %>% 
    map(~bind_rows(.) %>% group_by(SplitVar) %>% 
          summarise(rel_inf = sum(Sum_ErrRed)))
  varnames <- data.frame(Num = 0:(length(object$var.names)-1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes, 
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%  
    mutate(Class = classnames$Name[match(Class,classnames$Num)],
    SplitVar = varnames$Name[match(SplitVar,varnames$Num)]) %>%
    group_by(Class) 
  if(Scale == FALSE){
    return(out)
    } else {
    out <- out %>% mutate(Scaled_inf = rel_inf/max(rel_inf)*100)
    }
}

按类别绘制变量重要性

在我的实际用途中,我有40多个特征,因此我可以选择指定要绘制的特征数量。如果我想针对每个类分别对图进行排序,我也不能使用分面,这就是为什么我使用gridExtra的原因。

plot_imp_byclass <- function(df, n) {
  library(ggplot2)
  library(gridExtra)
  plot_imp_class <- function(df){
    df %>% arrange(rel_inf) %>% 
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>% 
      ggplot(aes(SplitVar, rel_inf))+
      geom_segment(aes(x = SplitVar, 
                       xend = SplitVar, 
                       y = 0, 
                       yend = rel_inf))+
      geom_point(size=3, col = "cyan") + 
      coord_flip()+
      labs(title = df$Class[[1]], x = "Variable", y = "Importance")+
      theme_classic()+
      theme(plot.title = element_text(hjust = 0.5))
  }

  df %>% top_n(n, rel_inf) %>% split(.$Class) %>% 
    map(plot_imp_class) %>% map(ggplotGrob) %>% 
    {grid.arrange(grobs = .)}
}

尝试

gbm_iris <- gbm(Species~., data = iris)
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees), 
                              gbm_iris$num.classes, Scale = F)
plot_imp_byclass(imp_byclass, 4)

如果您对所有类的结果求和,似乎会提供与内置relative.influence函数相同的结果。

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
# 0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# A tibble: 3 x 2
# SplitVar     Overall_rel_inf
# <fct>                  <dbl>
#   1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9