我已经训练了一个partykit包ctree分类决策树,我需要计算子树的分类概率(不仅仅是叶子节点)。 因此,例如,如果子树由3个具有以下概率的叶节点组成: 叶1(120观察):0.45 叶2(160观察):0.49 叶3(190观察):0.83
对于这个假设的子树,加权平均概率是 120 * 0.42 + 160 * 0.49 + 190 * 0.83 /(120 + 160 + 190)= 0.507
等等我需要遍历ctree对象并递归计算每个节点的所有加权概率。
我有这段代码:
data(airquality)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
controls = ctree_control(maxsurrogate = 3))
traverse <- function(treenode){
if(treenode$terminal){
bas=paste("Current node is terminal node with",treenode$nodeID,'prediction',treenode$prediction)
print(bas)
return(0)
} else {
bas=paste("Current node",treenode$nodeID,"Split var. ID:",treenode$psplit$variableName,"split value:",treenode$psplit$splitpoint,'prediction',treenode$prediction)
print(bas)
}
traverse(treenode$left)
traverse(treenode$right)
}
遍历树的对partykit对象不起作用。 另一方面,我有这个代码,它只列出叶节点的所有可能性:
preds.ls <- list(predict(airct , type = "prob"))[1]
pred.probs.df <- unique(as.data.frame((preds.ls[[1]])))
任何建议将这2个片段组合到将遍历PARTYKIT对象并计算此加权平均值的代码,我们不胜感激
答案 0 :(得分:0)
我不熟悉partykit
但是这个简单的函数遍历ctree
并提取每个内部和终端节点的概率:
library(party)
set.seed(100)
dt <- ctree(factor(mpg > 20)~., data = mtcars,
control = ctree_control(minsplit=2, minbucket=1, mincriterion=0))
traverse <- function(node) {
if (node$terminal) {
return(node$prediction[2])
}
return(c(node$prediction[2],
traverse(node$left), traverse(node$right)))
}
调用该函数会产生以下概率向量:
> traverse(dt@tree)
[1] 0.4375000 1.0000000 0.1428571 0.4285714 0.7500000 0.0000000 0.0000000
最左边的值是通过以下方式验证的人口值:
> mean(mtcars$mpg > 20)
[1] 0.4375
其余值将按从左到右的顺序排列。您可以看到1和0排列在预期的位置。