R partykit计算子stree上的分类概率

时间:2017-03-21 13:37:22

标签: r traversal decision-tree party

我已经训练了一个partykit包ctree分类决策树,我需要计算子树的分类概率(不仅仅是叶子节点)。 因此,例如,如果子树由3个具有以下概率的叶节点组成: 叶1(120观察):0.45 叶2(160观察):0.49 叶3(190观察):0.83

对于这个假设的子树,加权平均概率是 120 * 0.42 + 160 * 0.49 + 190 * 0.83 /(120 + 160 + 190)= 0.507

等等我需要遍历ctree对象并递归计算每个节点的所有加权概率。

我有这段代码:

data(airquality)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
                 controls = ctree_control(maxsurrogate = 3))
traverse <- function(treenode){
    if(treenode$terminal){
      bas=paste("Current node is terminal node with",treenode$nodeID,'prediction',treenode$prediction)
      print(bas)
      return(0)
    } else {
      bas=paste("Current node",treenode$nodeID,"Split var. ID:",treenode$psplit$variableName,"split value:",treenode$psplit$splitpoint,'prediction',treenode$prediction)
      print(bas)
    }
    traverse(treenode$left)
    traverse(treenode$right)
  }
遍历树的

对partykit对象不起作用。 另一方面,我有这个代码,它只列出叶节点的所有可能性:

preds.ls <- list(predict(airct , type = "prob"))[1]
pred.probs.df <- unique(as.data.frame((preds.ls[[1]])))

任何建议将这2个片段组合到将遍历PARTYKIT对象并计算此加权平均值的代码,我们不胜感激

1 个答案:

答案 0 :(得分:0)

我不熟悉partykit但是这个简单的函数遍历ctree并提取每个内部和终端节点的概率:

   library(party)

    set.seed(100)
    dt <- ctree(factor(mpg > 20)~., data = mtcars,
                control = ctree_control(minsplit=2, minbucket=1, mincriterion=0))

    traverse <- function(node) {
      if (node$terminal) {
        return(node$prediction[2])
      }
      return(c(node$prediction[2],
               traverse(node$left), traverse(node$right)))
    }

enter image description here

调用该函数会产生以下概率向量:

> traverse(dt@tree)
[1] 0.4375000 1.0000000 0.1428571 0.4285714 0.7500000 0.0000000 0.0000000

最左边的值是通过以下方式验证的人口值:

> mean(mtcars$mpg > 20)
[1] 0.4375

其余值将按从左到右的顺序排列。您可以看到1和0排列在预期的位置。