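Apologies in advance if I butcher this question; I am very new to R and to statistical analysis in general.
I have generated a conditional inference tree using the party library.
When I run plot(my_tree, type = "simple"), I get a result like this:
When I run print(my_tree), I get a result like this: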
1) SOME_VALUE <= 2.5; criterion = 1, statistic = 1306.478
  2) SOME_VALUE <= -10.5; criterion = 1, statistic = 173.416
    3) SOME_VALUE <= -16; criterion = 1, statistic = 19.385
      4)* weights = 275
    3) SOME_VALUE > -16
      5)* weights = 261
  2) SOME_VALUE > -10.5
    6) SOME_VALUE <= -2.5; criterion = 1, statistic = 24.094
      7) SOME_VALUE <= -6.5; criterion = 0.974, statistic = 4.989
        8)* weights = 346
      7) SOME_VALUE > -6.5
        9)* weights = 563
    6) SOME_VALUE > -2.5
      10)* weights = 442
1) SOME_VALUE > 2.5
  11) SOME_VALUE <= 10; criterion = 1, statistic = 225.148
    12) SOME_VALUE <= 6.5; criterion = 1, statistic = 18.789
      13)* weights = 648
    12) SOME_VALUE > 6.5
      14)* weights = 473
  11) SOME_VALUE > 10
    15) SOME_VALUE <= 16; criterion = 1, statistic = 51.729
      16)* weights = 595
    15) SOME_VALUE > 16
      17) SOME_VALUE <= 23.5; criterion = 0.997, statistic = 8.931
        18)* weights = 488
      17) SOME_VALUE > 23.5
        19)* weights = 365
I prefer the output of print, but it seems to be missing the y = (0.96, 0.04) values.
Ideally, I would like my output to look something like this:
1) SOME_VALUE <= 2.5; criterion = 1, statistic = 1306.478
  2) SOME_VALUE <= -10.5; criterion = 1, statistic = 173.416
    3) SOME_VALUE <= -16; criterion = 1, statistic = 19.385
      4)* weights = 275; y = (0.96, 0.04)
    3) SOME_VALUE > -16
      5)* weights = 261; y = (0.831, 0.169)
  2) SOME_VALUE > -10.5
...
How can I accomplish this?
Answer 0 (score: 3)
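This can be done with the partykit package (the successor to party), but even there it requires some hacking. In principle, the print() function can be customized through panel functions for the inner and terminal nodes, among others. However, these do not look very nice even for seemingly simple tasks like this one.
As you appear to be using a tree with a bivariate response, let's consider this simple (albeit not very meaningful) reproducible example: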
library("partykit")
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone + Wind ~ ., data = airq)
For the inner nodes, let's assume we only want to show the p-value that is available in each node's $info. We can format it via:
ip <- function(node) formatinfo_node(node,
  prefix = " ",
  FUN = function(info) paste0("[p = ", format.pval(info$p.value), "]")
)
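For the terminal nodes, we want to show the number of observations (assuming no weights were used) and the mean response. Both are precomputed in small lookup tables and then accessed through each node's $id: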
n <- table(ct$fitted[["(fitted)"]])
m <- aggregate(ct$fitted[["(response)"]], list(ct$fitted[["(fitted)"]]), mean)
m <- apply(m[, -1], 1, function(x) paste(round(x, digits = 3), collapse = ", "))
names(m) <- names(n)
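To make the lookup step easier to follow, here is a minimal base R sketch of what those four lines produce. The vectors fitted_toy and response_toy are made-up stand-ins (an assumption for illustration) for ct$fitted[["(fitted)"]] and ct$fitted[["(response)"]]; they are not taken from the airquality tree.

```r
# Hypothetical stand-ins for the fitted node ids and the bivariate response
fitted_toy <- c(3, 3, 4, 4, 5)
response_toy <- data.frame(
  Ozone = c(10, 20, 40, 50, 80),
  Wind  = c(12, 10,  9, 11,  7)
)

# Number of observations per terminal node, named by node id
n_toy <- table(fitted_toy)

# Mean of each response column per terminal node
m_toy <- aggregate(response_toy, list(fitted_toy), mean)

# Collapse the means of each node into one string, dropping the grouping column
m_toy <- apply(m_toy[, -1], 1, function(x) paste(round(x, digits = 3), collapse = ", "))
names(m_toy) <- names(n_toy)

# Both tables can now be indexed by the node id as a character string
n_toy[["4"]]
m_toy[["4"]]
```

Indexing by as.character(node$id) rather than by position matters because terminal node ids are not consecutive (here 3, 4, 5).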
The panel function is then defined by:
tp <- function(node) formatinfo_node(node,
  prefix = ": ",
  FUN = function(info) paste0(
    "n = ", n[as.character(node$id)],
    ", y = (", m[as.character(node$id)], ")"
  )
)
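To apply this in the print() method, we need to call print.party() directly, because print.constparty() currently does not pass these arguments on correctly. (We will have to fix this in the partykit package.)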
print.party(ct, inner_panel = ip, terminal_panel = tp)
## [1] root
## | [2] Temp <= 82 [p = 0.0044842]
## | | [3] Temp <= 77: n = 52, y = (18.615, 11.562)
## | | [4] Temp > 77: n = 27, y = (41.815, 9.737)
## | [5] Temp > 82: n = 37, y = (75.405, 7.565)
Hopefully this is close to what you want to do, and it should give you a template for further modifications.
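If the y = (0.96, 0.04) values in the question are class proportions of a two-level factor response rather than means of a bivariate response, the same template might be adapted roughly as follows. This is only a sketch under that assumption: the two-class iris subset, and the names ir, ct2, n2, p2, y2 and tp2, are illustrative and not from the question. The print method is fetched with getS3method(), which yields the same print.party() method that is called directly above.

```r
library("partykit")

# Illustrative data (an assumption, not from the question): two-class iris subset
ir <- droplevels(subset(iris, Species != "setosa"))
ct2 <- ctree(Species ~ ., data = ir)

# Terminal node id and observed class for every observation
fit_id <- ct2$fitted[["(fitted)"]]
fit_y <- ct2$fitted[["(response)"]]

# Observations per terminal node and row-wise class proportions
n2 <- table(fit_id)
p2 <- prop.table(table(fit_id, fit_y), margin = 1)
y2 <- apply(p2, 1, function(x) paste(round(x, digits = 3), collapse = ", "))

# Terminal panel mimicking the "weights = ...; y = (...)" layout from the question
tp2 <- function(node) formatinfo_node(node,
  prefix = ": ",
  FUN = function(info) paste0(
    "weights = ", n2[as.character(node$id)],
    "; y = (", y2[as.character(node$id)], ")"
  )
)

# Call the party print method directly, bypassing print.constparty()
print_party <- getS3method("print", "party")
print_party(ct2, terminal_panel = tp2)
```

As before, "weights" here is just the observation count, so this only matches the party output when no case weights were used.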