使用rpart生成Sankey图的决策树

时间:2018-09-06 10:42:08

标签: r plotly decision-tree rpart sankey-diagram

我可以使用属于基础R的Kyphosis数据集使用Rpart创建树:

fit <- rpart(Kyphosis ~ Age + Number + Start,
         method="class", data=kyphosis)
printcp(fit)
plot(fit, uniform=TRUE,main="Classification Tree for Kyphosis")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

这是树的样子: enter image description here

现在为了更好地可视化树,我想使用plotly使用sankey图。要在图中绘制一个Sankey图,必须执行以下操作:

library(plotly)
nodes=c("Start>=8.5","Start>-14.5","absent",
                   "Age<55","absent","Age>=111","absent","present","present")
p <- plot_ly(
  type = "sankey",
  orientation = "h",      
  node = list(
    label = nodes,
    pad = 10,
    thickness = 20,
    line = list(
      color = "black",
      width = 0.5
    )
  ),

  link = list(
    source = c(0,1,1,3,3,5,5,0),
    target = c(1,2,3,4,5,6,7,8),
    value =  c(1,1,1,1,1,1,1,1)
  )
) %>% 
  layout(
    title = "Desicion Tree",
    font = list(
      size = 10
    )
  )
p

这将创建与树对应的sankey图(硬编码)。所需的三个必要向量是“源”,“目标”,“值”,其外观如下:

硬编码的sankey图:

enter image description here

我的问题是使用rpart对象“ fit”,我似乎无法轻松地获得矢量来产生所需的用于绘制的“源”,“目标”和“值”矢量。

fit $ frame和fit $ splits包含一些信息,但是很难将它们汇总或一起使用。在fit对象上使用打印功能会生成所需的信息,但我不想进行文本编辑来获取它。

print(fit)

输出:

1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

那么,有没有一种简单的方法可以使用rpart对象获得这三个向量以作图生成Sankey图?此图将在Web应用程序中使用,因此必须使用图,因为我们已经具有与之对应的JavaScript,并且必须易于重用才能应用于各种数据集。

1 个答案:

答案 0 :(得分:2)

这是我的尝试:

从我看来,挑战在于生成nodessource变量。

样本数据:

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)

生成nodes

frame <- fit$frame
isLeave <- frame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[frame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

生成source

node <- as.numeric(row.names(frame))
depth <- rpart:::tree.depth(node)
source <- depth[-1] - 1

reps <- rle(source)
tobeAdded <- reps$values[sapply(reps$values, function(val) sum(val >= which(reps$lengths > 1))) > 0]
update <- source %in% tobeAdded
source[update] <- source[update] + sapply(tobeAdded, function(tobeAdd) rep(sum(which(reps$lengths > 1) <= tobeAdd), 2))

经过以下测试:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)
fit2 <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              parms = list(prior = c(.65,.35), split = "information"))

如何到达那里:

请参阅:getS3method("print", "rpart")