我用rpart做了一个回归树,根据一些变量评估了老年人的行走情况。通过使用该图,我想将输出用于其他软件中的进一步分析。但是,我想知道是否不仅可以从叶节点推导出每组的步行,而且是否可以从叶节点推导标准差(就行走而言)?
#### Decision tree with rpart
modelRT <- rpart(logwalkin~.-walkinmin-walkingtime, data=trainDF,
control=rpart.control(minsplit=25, maxdepth = 8, cp =0.00005))
rpart.plot(modelRT,type=3,digits=3,fallen.leaves=TRUE)
答案 0 :(得分:0)
我认为您不能从图中完成,但是您当然可以从rpart模型中得出每个叶节点的标准差。由于您不提供数据,因此我将使用内置的虹膜数据来举例说明。由于您对回归感兴趣,因此我将消除类变量(Species),并从其他变量中预测变量Sepal.Length。
设置
library(rpart)
library(rpart.plot)
RP = rpart(Sepal.Length ~ ., data=iris[,-5])
rpart.plot(as.party(RP))
如您所见,节点4,5,6,10,11,12和13是叶节点。返回的结构RP$where
的一部分告诉您原始实例去了哪片叶子。因此,您只需要使用此变量进行汇总。
SD = aggregate(iris$Sepal.Length, list(RP$where), sd)
SD
Group.1 x
1 4 0.2390221
2 5 0.2888391
3 6 0.2500526
4 10 0.4039577
5 11 0.3802046
6 12 0.3020486
7 13 0.2279132
Group.1告诉您哪个叶子节点,x告诉您在该叶子中结束的点的标准偏差。如果您希望将标准偏差添加到绘图中,则可以使用mtext
进行。摆弄摆放位置后:
rpart.plot(RP)
mtext(text=round(SD$x,1), side=1, line=3.2, at=seq(0.06,1,0.1505))
答案 1 :(得分:0)
要在树的每个节点上绘制标准偏差,可以将rpart.plot
与
node.fun
,如第6章中所述
rpart.plot package vignette。
例如
library(rpart.plot)
data(iris)
tree = rpart(Sepal.Length~., data=iris, cp=.05) # example tree
# Calculate the standard deviation at each node of the tree.
sd <- sqrt(tree$frame$dev / (tree$frame$n-1))
# Append the standard deviation as an extra column to the tree frame.
tree$frame <- cbind(tree$frame, sd)
# Create a node.fun to print the standard deviation at each node.
# See Chapter 6 of the rpart.plot vignette http://www.milbo.org/doc/prp.pdf.
node.fun.sd <- function(x, labs, digits, varlen)
{
s <- round(x$frame$sd, 2) # round sd to 2 digits
paste(labs, "\n\nsd", s)
}
# Plot the tree, using the node.fun to add the standard deviation to each node
rpart.plot(tree, type=4, node.fun=node.fun.sd)
给出
如果您只想在叶子节点上使用标准偏差(而不是 内部节点),您可以执行以下操作:
library(rpart.plot)
data(iris)
tree = rpart(Sepal.Length~., data=iris, cp=.05)
sd <- sqrt(tree$frame$dev / (tree$frame$n-1))
is.leaf <- tree$frame$var == "<leaf>" # logical vec, indexed on row in frame
sd[!is.leaf] <- NA # change sd of non-leaf nodes to NA
tree$frame <- cbind(tree$frame, sd)
node.fun2 <- function(x, labs, digits, varlen)
{
s <- paste("\n\nsd", round(x$frame$sd, 2)) # round sd to 2 digits
s[is.na(x$frame$sd)] <- "" # delete NAs
paste(labs, s)
}
rpart.plot(tree, type=4, node.fun=node.fun2)
给出