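I need to extract the rules from a decision tree. I am using the rpart package in R, and I'll use the demo data that ships with the package to explain what I need: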
library(rpart)
data(stagec)
fit <- rpart(formula = pgstat ~ age + eet + g2 + grade + gleason + ploidy, data = stagec, method = "class", control = rpart.control(cp = 0.05))
fit
Printing fit shows:
n= 146

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 146 54 0 (0.6301370 0.3698630)
   2) grade< 2.5 61 9 0 (0.8524590 0.1475410) *
   3) grade>=2.5 85 40 1 (0.4705882 0.5294118)
     6) g2< 13.2 40 17 0 (0.5750000 0.4250000)
      12) ploidy=diploid,tetraploid 31 11 0 (0.6451613 0.3548387) *
      13) ploidy=aneuploid 9 3 1 (0.3333333 0.6666667) *
     7) g2>=13.2 45 17 1 (0.3777778 0.6222222)
      14) g2>=17.91 22 8 0 (0.6363636 0.3636364) *
      15) g2< 17.91 23 3 1 (0.1304348 0.8695652) *
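For example, I would like to get the information for node 12 in the following form:
If grade >= 2.5 and g2 < 13.2 and ploidy in (diploid, tetraploid), then class 0 is predicted with 65% confidence.
Any pointers on this would be very helpful.
Thanks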
Answer 0 (score: 1)
The rpart.plot package, version 3.0 (July 2018), has a function rpart.rules for generating a set of rules for a tree. For example,
library(rpart.plot)
data(stagec)
fit <- rpart(formula = pgstat ~ ., data = stagec, method = "class", control=rpart.control(cp=0.05))
rpart.rules(fit)
gives
pgstat
0.15 when grade < 3
0.35 when grade >= 3 & g2 < 13 & ploidy is diploid or tetraploid
0.36 when grade >= 3 & g2 >= 18
0.67 when grade >= 3 & g2 < 13 & ploidy is aneuploid
0.87 when grade >= 3 & g2 is 13 to 18
and
rpart.rules(fit, roundint=FALSE, clip.facs=TRUE)
gives
pgstat
0.15 when grade < 2.5
0.35 when grade >= 2.5 & g2 < 13 & diploid or tetraploid
0.36 when grade >= 2.5 & g2 >= 18
0.67 when grade >= 2.5 & g2 < 13 & aneuploid
0.87 when grade >= 2.5 & g2 is 13 to 18
For more examples, see the rpart.plot vignette.
Answer 1 (score: 0)
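You can use the .list.rules.party() function from the partykit package, together with a bit of string formatting. The function is not exported, so it has to be called with the ::: operator. Here is an example using your code.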
library(rpart)
library(partykit)
data(stagec)
fit <- rpart(
  formula = pgstat ~ age + eet + g2 + grade + gleason + ploidy,
  data = stagec,
  method = "class",
  control = rpart.control(cp = 0.05)
)
# Convert the rpart tree into a party object
party_obj <- as.party(fit, data = TRUE)
# .list.rules.party() is internal to partykit, hence the ':::'
decisions <- partykit:::.list.rules.party(party_obj)
cat(paste(decisions, collapse = "\n"))
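As you can see, you build the tree model in the same way. You then convert the model to a party object and use the .list.rules.party() function to extract the decision strings. With a little formatting, you get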
grade < 2.5
grade >= 2.5 & g2 < 13.2 & ploidy %in% c("diploid", "tetraploid")
grade >= 2.5 & g2 < 13.2 & ploidy %in% c("aneuploid")
grade >= 2.5 & g2 >= 13.2 & g2 >= 17.91
grade >= 2.5 & g2 >= 13.2 & g2 < 17.91
as the result.
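Neither set of rule strings above is phrased quite the way the question asks, with the predicted class and its confidence spelled out. As a rough sketch that is not part of either answer, the same sentence can be assembled with rpart alone: path.rpart() returns the split conditions leading to a node, and the yval2 matrix in fit$frame holds the class probabilities. The helper name rule_for is made up for illustration, and the code assumes the yval2 column layout used for method = "class" (fitted class, then class counts, then class probabilities).

```r
library(rpart)

data(stagec)
fit <- rpart(pgstat ~ age + eet + g2 + grade + gleason + ploidy,
             data = stagec, method = "class",
             control = rpart.control(cp = 0.05))

# Hypothetical helper: describe one node as an "If ... then ..." sentence.
rule_for <- function(fit, node) {
  frame <- fit$frame
  # Row names of the frame are the node numbers shown by print(fit)
  i <- match(as.character(node), rownames(frame))
  if (is.na(i)) stop("no such node: ", node)

  # Split conditions from the root down to the node; drop the leading "root"
  conds <- path.rpart(fit, nodes = node, print.it = FALSE)[[1]][-1]

  # Assumed yval2 layout: fitted class | counts per class | probs per class | ...
  ylevels <- attr(fit, "ylevels")
  cls <- frame$yval[i]                      # index of the predicted class
  prob <- frame$yval2[i, 1 + length(ylevels) + cls]

  sprintf("If %s then class %s is predicted with %.0f%% confidence",
          paste(conds, collapse = " and "), ylevels[cls], 100 * prob)
}

# One sentence per terminal node
leaves <- as.numeric(rownames(fit$frame)[fit$frame$var == "<leaf>"])
rules <- vapply(leaves, function(n) rule_for(fit, n), character(1))
cat(rules, sep = "\n")
```

Calling rule_for(fit, 12) should give the sentence for node 12, with the conditions written in rpart's own label format (for example g2< 13.2) rather than as R expressions.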