将规则导出到数据框并链接规则以训练数据

时间:2018-01-28 18:04:15

标签: r decision-tree rpart

我用rpart训练了一些数据并且有兴趣用树终端节点标记每个观察, 并链接到与该终端节点对应的规则。

我使用以下代码作为示例:

library(rpart)
library(rattle)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
table(fit$where)
rattle::asRules(fit)

我可以通过适合$ where来标记每个观察, 标签是:

> table(fit$where)

 3  5  7  8  9 
29 12 14  7 19 

第一个问题:这些标签与rattle :: asRules(fit)生成的标签不一致,即3,23,22,10,4 如何在两者之间生成映射表?

第二个问题:asRules只是打印,而我想将规则放在表中而不是标准输出。

我的预期结果:一个数据框,其中包含fit $ where和asRules标签之间的映射,另一列带有规则文本作为字符串,例如:

 Rule number: 4 [Kyphosis=absent cover=29 (36%) prob=0.00]
   Start>=8.5
   Start>=14.5

如果我们可以在单独的列中解析文本到ID,统计和条件,甚至更好但不是强制性的。

我找到了许多相关问题和链接,但没有找到最终答案。

非常感谢, Kamashay

进度更新29/01

如果我有规则ID,我可以通过path.rpart单独提取每个规则:

>path.rpart(fit,node=22) 

 node number: 22 
   root
   Start>=8.5
   Start< 14.5
   Age>=55
   Age>=111

这使得规则成为我可以转换为字符串的列表。 但是这些ID是'asRules'功能而不是'适合$ where'...

使用“partykit”会得到与“fit $ where”相同的结果:

library("partykit")
> table(predict(as.party(fit), type = "node"))

 3  5  7  8  9 
29 12 14  7 19 

所以,我仍然无法在两者之间建立链接(asRules ID并且适合$ where ID), 我可能遗漏了一些基本的东西,或者有一种更简单的方法来完成任务。

你能帮忙吗?

4 个答案:

答案 0 :(得分:3)

您可以使用

找到与每个拟合$对应的规则编号(实际上是叶节点编号)
> row.names(fit$frame)[fit$where]
 [1] "3"  "22" "3"  "3"  "4"  "4"  ...

您可以通过

更接近所需的输出
> rattle::asRules(fit, TRUE)
R  3 [23%,0.58] Start< 8.5
R 23 [ 9%,0.57] Start>=8.5 Start< 14.5 Age>=55 Age< 111
...

答案 1 :(得分:3)

你的意思是这样吗?

library(rpart)
library(rpart.utils)
library(dplyr)

#model
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

#dataframe having leaf node's rule and subrule combination
rule_df <- rpart.rules.table(fit) %>%
  filter(Leaf==TRUE) %>%
  group_by(Rule) %>%
  summarise(Subrules = paste(Subrule, collapse=","))

#final dataframe
df <- kyphosis %>%
  mutate(Rule = row.names(fit$frame)[fit$where]) %>%
  left_join(rule_df, by="Rule")
head(df)

#subrule table
rpart.subrules.table(fit)

输出为:

  Kyphosis Age Number Start Rule    Subrules
1   absent  71      3     5    3          R1
2   absent 158      3    14   22 L1,R2,R3,L4
3  present 128      4     5    3          R1
4   absent   2      5     1    3          R1
5   absent   1      4    15    4       L1,L2
6   absent   1      2    16    4       L1,L2

子规则定义:

  Subrule Variable Value Less Greater
1      L1    Start   8.5 <NA>     8.5
2      L2    Start  14.5 <NA>    14.5
3      L3      Age  <NA>   55    <NA>
4      L4      Age   111 <NA>     111
5      R1    Start  <NA>  8.5    <NA>
6      R2    Start  <NA> 14.5    <NA>
7      R3      Age    55 <NA>      55
8      R4      Age  <NA>  111    <NA>

答案 2 :(得分:1)

您可以通过以下方式获得规则(叶)的数量:

nrules <- as.integer(rownames(fit$frame[fit$frame$var == "<leaf>",]))

您还可以迭代如下规则:

rules <- lapply(nrules, path.rpart, tree=fit, pretty=0, print.it=FALSE)

另一种替代方法是使用软件包rpart.plot

rules <- rpart.plot::rpart.rules(model, cover=T, nn=T)

答案 3 :(得分:0)

这是值得的,毕竟这是我使用的:

[1]用于在fit $ where和asRules之间对齐标签我使用了@Graham Williams的解决方案, 或者通过采用@VitoshKa中的函数来获得正确的标签:https://stackoverflow.com/a/30088268/8263160

[2]用于在数据框中创建格式良好的规则列表我采用并修改了TomášGreif的parse_tree函数:  https://www.r-bloggers.com/create-sql-rules-from-rpart-model/