如何绘制经过训练的rpart决策树模型的变量重要性?

时间:2019-05-25 12:09:36

标签: r ggplot2 plot rpart

我使用rpart训练了一个模型,我想生成一个显示用于决策树的变量的变量重要性的图,但是我不知道怎么做。

我能够提取变量重要性。我尝试了ggplot,但是没有显示任何信息。我尝试在上面使用plot()函数,但它只给我一个平面图。我还尝试了plot.default,这要好一些,但现在仍然是我想要的。

这里的rpart模型训练:

argIDCART = rpart(Argument ~ ., 
                  data = trainSparse, 
                  method = "class")

将变量重要性放入数据框中。

argPlot <- as.data.frame(argIDCART$variable.importance)

以下是打印内容的一部分:

       argIDCART$variable.importance
noth                             23.339346
humanitarian                     16.584430
council                          13.140252
law                              11.347241
presid                           11.231916
treati                            9.945111
support                           8.670958

我想绘制一个图表,显示变量/特征名称及其数值重要性。我只是无法做到这一点。它似乎只有一列。我尝试使用单独的功能将它们分开,但也不能这样做。

ggplot(argPlot, aes(x = "variable importance", y = "feature"))

只打印空白。

其他情节看起来真的很糟糕。

plot.default(argPlot)

看起来像绘制点,但没有放置变量名称。

2 个答案:

答案 0 :(得分:0)

如果要查看变量名称,最好将它们用作x轴上的标签。

plot(argIDCART$variable.importance, xlab="variable", 
    ylab="Importance", xaxt = "n", pch=20)
axis(1, at=1:7, labels=row.names(argIDCART))

Variable Importance

(您可能需要调整窗口大小以正确查看标签。)

如果您有很多变量,则可能需要旋转变量名,以使变量名不重叠。

par(mar=c(7,4,3,2))
plot(argIDCART$variable.importance, xlab="variable", 
    ylab="Importance", xaxt = "n", pch=20)
axis(1, at=1:7, labels=row.names(argIDCART), las=2)

Rotated axis labels

数据

argIDCART = read.table(text="variable.importance
noth                             23.339346
humanitarian                     16.584430
council                          13.140252
law                              11.347241
presid                           11.231916
treati                            9.945111
support                           8.670958", 
header=TRUE)

答案 1 :(得分:0)

由于没有可用的可复制示例,我使用ggplot2软件包和其他用于数据处理的软件包,基于自己的R数据集安装了响应。

library(rpart)
library(tidyverse)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
df <- data.frame(imp = fit$variable.importance)
df2 <- df %>% 
  tibble::rownames_to_column() %>% 
  dplyr::rename("variable" = rowname) %>% 
  dplyr::arrange(imp) %>%
  dplyr::mutate(variable = forcats::fct_inorder(variable))
ggplot2::ggplot(df2) +
  geom_col(aes(x = variable, y = imp),
           col = "black", show.legend = F) +
  coord_flip() +
  scale_fill_grey() +
  theme_bw()

enter image description here

ggplot2::ggplot(df2) +
  geom_segment(aes(x = variable, y = 0, xend = variable, yend = imp), 
               size = 1.5, alpha = 0.7) +
  geom_point(aes(x = variable, y = imp, col = variable), 
             size = 4, show.legend = F) +
  coord_flip() +
  theme_bw()

enter image description here