rpart:如何获得"其中"验证数据集的向量?

时间:2018-03-09 17:32:26

标签: r decision-tree rpart

当与rpart匹配时,它返回"其中"向量,告诉我们将训练数据集中的每条记录留在树上。是否有一个功能可以返回与此相似的内容"其中"测试数据集的向量?

2 个答案:

答案 0 :(得分:0)

我认为partykit包可以满足您的需求

library('rpart')
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
fit
rpart.plot::rpart.plot(fit)

enter image description here

检查相同的数据

set.seed(1)
idx <- sample(nrow(kyphosis), 5L)
fit$where[idx]
# 22 30 46 71 16 
#  9  3  7  7  3 

library('partykit')
fit <- as.party(fit)
predict(fit, kyphosis[idx, ], type = 'node')
# 22 30 46 71 16 
#  9  3  7  7  3 

检查新数据

dd <- kyphosis[idx, ]
set.seed(1)
dd[] <- lapply(dd, sample)
predict(fit, dd, type = 'node')
# 22 30 46 71 16 
#  5  3  7  9  3 

## so #46 should meet criteria for the 7th leaf:
with(kyphosis[46, ],
     Start  >= 8.5  &  # node 1
       Start < 14.5 &  # node 2
       Age  >= 55   &  # node 4
       Age  >= 111     # node 6
)
# [1] TRUE

答案 1 :(得分:0)

正如您提到的,predict.rpart包中的函数rpart 没有where选项(以显示关联的叶节点号 预测)。 但是,rpart.predict包中的rpart.plot函数 将做到这一点。例如

> library(rpart.plot)
> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)

给出(请注意节点号nn列):

   absent present nn
1 0.42105 0.57895  3
2 0.85714 0.14286 22
3 0.42105 0.57895  3

> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)$nn

仅提供where节点号:

[1]  3 22  3

显示每种预测使用的规则

> rpart.predict(fit, newdata=kyphosis[1:5,], rules=TRUE)

给出

   absent present
1 0.42105 0.57895 because Start <  9
2 0.85714 0.14286 because Start is 9 to 15 & Age >= 111
3 0.42105 0.57895 because Start <  9