说我有
head(kyphosis)
inTrain <- sample(1:nrow(kyphosis), 45, replace = F)
TRAIN_KYPHOSIS <- kyphosis[inTrain,]
TEST_KYPHOSIS <- kyphosis[-inTrain,]
(kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS))
如何在TEST_KYPHOSIS
?
如何获取摘要,例如每个测试观察映射到的终端节点的偏差和预测值?
答案 0 :(得分:7)
rpart
实际上有这个功能,但它没有暴露(奇怪的是,它是一个相当明显的要求)。
predict_nodes <-
function (object, newdata, na.action = na.pass) {
where <-
if (missing(newdata))
object$where
else {
if (is.null(attr(newdata, "terms"))) {
Terms <- delete.response(object$terms)
newdata <- model.frame(Terms, newdata, na.action = na.action,
xlev = attr(object, "xlevels"))
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, newdata, TRUE)
}
rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
}
as.integer(row.names(object$frame))[where]
}
然后:
> predict_nodes(kyph_tree, TEST_KYPHOSIS)
[1] 5 3 4 3 3 5 5 3 3 3 3 5 5 4 3 5 4 3 3 3 3 4 3 4 4 5 5 3 4 4 3 5 3 5 5 5
答案 1 :(得分:5)
一种选择是将rpart
对象从party
包转换为类partykit
的对象。这提供了处理递归聚会的通用工具包。转换很简单:
library("partykit")
(kyph_party <- as.party(kyph_tree))
Model formula:
Number ~ Kyphosis + Age + Start
Fitted party:
[1] root
| [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
| [3] Start < 15.5
| | [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
| | [5] Age < 112.5: 5.125 (n = 16, err = 29.8)
Number of inner nodes: 2
Number of terminal nodes: 3
(为了完全重现,请在运行我的代码之前使用set.seed(1)
运行问题中的代码。)
对于此类的对象,plot()
,predict()
,fitted()
等有一些更灵活的方法。例如,plot(kyph_party)
产生的信息比默认plot(kyph_tree)
。 fitted()
方法使用拟合的节点编号和观察到的对训练数据的响应来提取两列data.frame
。
kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)
(fitted) (response)
1 5 6
2 2 2
3 4 3
通过这种方式,您可以轻松计算出您感兴趣的任何数量,例如每个节点内的平方,中位数或残差平方和。
tapply(kyph_fit[,2], kyph_fit[,1], mean)
2 4 5
2.933333 3.714286 5.125000
tapply(kyph_fit[,2], kyph_fit[,1], median)
2 4 5
3 4 5
tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))
2 4 5
10.93333 18.85714 29.75000
您可以使用您选择的任何其他函数来计算分组统计信息表,而不是简单的tapply()
。
现在要了解从测试数据TEST_KYPHOSIS
到树中哪个节点的哪个观察点,您只需使用predict(..., type = "node")
方法:
kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)
2 3 4 6 7 10
4 4 5 2 2 5