我想在树上添加一些信息。比方说,我有一个像这样的数据库:
library(rpart)
library(rpart.plot)
set.seed(1)
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))
我可以跑一棵树:
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8))
对我来说没关系,但我想我想知道每片叶子的平均暴露量。
我知道我可以为prp添加一些信息,例如每个叶子的重量都有一个函数:
node.fun1 <- function(x, labs, digits, varlen)
{
paste("Weight \n",x$frame$wt)
}
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8),node.fun = node.fun1)
但它只有在帧中计算时才有效,即rpart函数的结果。
如何向绘图添加自定义信息,例如平均曝光,或计算自定义指标的任何其他函数,并将其添加到表frame
?
答案 0 :(得分:1)
这真的很好,我不知道这是一个选择。
所有工作似乎都是获取每个节点上使用的原始数据的子集。这对于终端节点来说很容易,但我没有找到一种直接的方法来识别每个节点中使用的数据行,而不仅仅是叶子。如果有人知道一种更简单的方式,我很乐意听到它。
library('rpart.plot')
set.seed(1)
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
rpart.plot(pfit)
定义你的新函数,它取x
,拟合rpart
的结果(我没有查看其他参数,但小插图应该有用)。
对于x$frame
的每一行,我们需要获取用于计算汇总统计数据的数据。不幸的是,x$where
只告诉我们每个观察所在的终端节点。因此,对于每个节点编号,我们使用subset.rpart
来获取基础数据,并随意执行任何操作
f <- function(x, labs, digits, varlen) {
nodes <- as.integer(rownames(x$frame))
z <- sapply(nodes, function(y) {
data <- subset.rpart(x, y)
c(mean = mean(data$expo), nrow(data), nrow(data) / length(x$where) * 100)
})
sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ])
}
prp(pfit, type=1, extra=100, fallen.leaves=FALSE,
shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8),
node.fun = f)
工作由subset.rpart
完成,它获取节点编号并返回节点上使用的data
子集。
subset.rpart <- function(tree, node = 1L) {
## returns subset of tree$call$data used on any node
data <- eval(tree$call$data, parent.frame(1L))
wh <- sapply(as.integer(rownames(tree$frame)), parent)
wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}
parent <- function(x) {
## returns vector of parent nodes
if (x[1] != 1)
c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}
测试
## tests
dim(subset.rpart(pfit, 1)) == dim(mydb)
# [1] TRUE TRUE
## terminal nodes
nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ]))
sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb)
# [1] TRUE
答案 1 :(得分:0)
我不知道这是否正是您想要的,但是请尝试使用“ sparkline”和“ visNetwork”软件包。它们与rpart对象一起使用