我是R(和rpart)的新手。我有车型数据(~400个型号)。我使用rpart将这些组合成一个较小的数字(例如5-10组),具有类似的车辆维修费用。我已成功运行rpart并拥有这些分组。
fit <- rpart(repairs ~ model, data=data, method='anova', control=rpart.control(minsplit=2,minbucket=1,cp=.0005))
假设每个终端节点中大约有40-80个型号。有没有办法让我创建一个引用终端节点中的值的公式。假设数据$ model包含所有模型名称(并且是我试图做的独立变量:
data$modelgroup <- data$model
data$modelgroup[data$modelgroup %in% terminal node 1] <- 'Group1'
data$modelgroup[data$modelgroup %in% terminal node 2] <- 'Group2'
and so on for the rest of the groups
此外,如果有一种方法可以做到这一点,而不必为每个组都有一行代码,那就不错了。
我知道我可以打印树并手动从终端节点复制文本并以这种方式完成,但效率非常低。
预先感谢您的协助!
根据下面的请求,我在下面添加了一个可重复的示例。
data <- read.csv("rpart_example.csv")
data
data[,1:2]
Model Amount
1 a 1
2 a 1
3 a 1
4 b 1
5 b 1
6 b 1
7 c 2
8 c 2
9 c 2
10 d 2
11 d 2
12 d 2
13 e 3
14 e 3
15 e 3
16 f 4
17 f 4
18 f 4
fit <- rpart(Amount ~ Model, data=data, method='anova',
control=rpart.control(minsplit=2,minbucket=1,cp=.0005))
print(fit)
n= 18
node), split, n, deviance, yval
* denotes terminal node
1) root 18 20.5 2.166667
2) Model=a,b,c,d 12 3.0 1.500000
4) Model=a,b 6 0.0 1.000000 *
5) Model=c,d 6 0.0 2.000000 *
3) Model=e,f 6 1.5 3.500000
6) Model=e 3 0.0 3.000000 *
7) Model=f 3 0.0 4.000000 *
# create a variable modelgroup that groups models per terminal nodes from rpart
# I can do this manually as below
# is there a way for me to automate this assignment?
data$modelgroup <- as.character(data$Model)
# per rpart output, a&b are grouped into one terminal node
data$modelgroup[data$modelgroup %in% c('a','b')] <- 'Group1'
# per rpart output, c&d are grouped into the second terminal node
data$modelgroup[data$modelgroup %in% c('c','d')] <- 'Group2'
# per rpart, e is the third terminal node
data$modelgroup[data$modelgroup == 'e'] <- 'Group3'
# per rpart, f is the fourth terminal node
data$modelgroup[data$modelgroup == 'f'] <- 'Group4'
答案 0 :(得分:2)
在rpart
个对象中,您要查找的信息基本上可以存储在$where
元素中。它为您提供了每个观测值分配的节点号:
table(fit$where, data$modelgroup)
## Group1 Group2 Group3 Group4
## 3 6 0 0 0
## 4 0 6 0 0
## 6 0 0 3 0
## 7 0 0 0 3
当然,您也可以将节点ID(3,4,6,7)切换为因子或字符变量,例如factor(fit$where, levels = c(3, 4, 6, 7), labels = paste0("Group", 1:4))
或沿着这些行的某些内容。
如果您想通过简单统一的界面对新数据执行此操作,则可以将rpart
对象转换为包party
中的partykit
对象:
library("partykit")
fit2 <- as.party(fit)
print(fit2)
和plot(fit2)
的统一方法以及不同类型的predict(fit2, ...)
可用:
table(predict(fit2, newdata = data, type = "node"), data$modelgroup)
## Group1 Group2 Group3 Group4
## 3 6 0 0 0
## 4 0 6 0 0
## 6 0 0 3 0
## 7 0 0 0 3
这会返回与上面相同的结果,但也可以轻松应用于其他newdata
。