如何为每一行获取rpart
模型的终端节点的ID(或名称)? predict.rpart
只能为分类树返回预测的类(数字或因子)或类概率或某种组合(使用type="matrix"
)。
我想做点什么:
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit) # there are 5 terminal nodes
predict(fit, type = "node_id") # should return IDs of terminal nodes (e.g. 1-5) (does not work)
答案 0 :(得分:6)
partykit
包支持predict(..., type = "node")
,包括样本内外。您只需转换rpart
对象即可使用它:
library("partykit")
predict(as.party(fit), type = "node")
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 ## 9 5 8
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 ## 9 5 8
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 ## 5 7 9
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 ## 5 7 9
## 5
## 5
table(predict(as.party(fit), type = "node"))
## 3 5 7 8 9
## 29 12 14 7 19
答案 1 :(得分:5)
对于该模型,有4个分裂,产生5个“终端节点”或在rpart中使用的术语:<leaf>
s。我不明白为什么应该有5个预测。预测适用于特定情况,叶子是用于进行这些预测的可变数量的分割的结果。最终在叶子中的原始数据集中的行数可能是您想要的,在这种情况下,这些是获取这些数字的方法:
# Row-wise predicted class
fit$where
# counts of cases in leaves of prediction rules
table(fit$where)
3 5 7 8 9
29 12 14 7 19
为了组合适用于特定叶子的labels(fit)
,您需要遍历规则树并累积应用于生成特定叶子的所有分割的所有标签。你可能想看看:
?print.rpart
?rpart.object
?text.rpart
?labels.rpart
答案 2 :(得分:1)
上面的方法使用$ where只弹出树框架中的行号。因此,在使用kyphosis$ID = fit$where
时,可能会为某些观察点分配节点ID而不是叶节点ID
要获取实际的叶节点ID,请使用以下命令:
MyID <- row.names(fit$frame)
kyphosis$ID <- MyID[fit$where]