很好地打印决策树/自定义控件[r]

时间:2018-06-12 18:27:09

标签: r decision-tree rpart party

我想在文本中打印一个决策树。例如,我可以打印树对象本身:

library(rpart)

f = as.formula('Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species')
fit = rpart(f, data = iris, control = rpart.control(xval = 3))

fit

产量

n= 150 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 150 102.1683000 5.843333  
   2) Petal.Length< 4.25 73  13.1391800 5.179452  
     4) Petal.Length< 3.4 53   6.1083020 5.005660  
       8) Sepal.Width< 3.25 20   1.0855000 4.735000 *
       9) Sepal.Width>=3.25 33   2.6696970 5.169697 *
... # omitted

partykit打印整洁:

library(partykit)

as.party(fit)

产量

Model formula:
Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species

Fitted party:
[1] root
|   [2] Petal.Length < 4.25
|   |   [3] Petal.Length < 3.4
|   |   |   [4] Sepal.Width < 3.25: 4.735 (n = 20, err = 1.1)
|   |   |   [5] Sepal.Width >= 3.25: 5.170 (n = 33, err = 2.7)
|   |   [6] Petal.Length >= 3.4: 5.640 (n = 20, err = 1.2)
...# omitted

Number of inner nodes:    6
Number of terminal nodes: 7

有没有办法让我有更多的控制权?例如,我不想打印nerr,或者希望打印标准偏差而不是err

2 个答案:

答案 0 :(得分:1)

不是一个非常优雅的答案,但如果您只想摆脱err=CO = capture.output(print(as.party(fit))) CO2 = sub("\\(.*\\)", "", CO) cat(paste(CO2, collapse="\n")) Model formula: Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species Fitted party: [1] root | [2] Petal.Length < 4.25 | | [3] Petal.Length < 3.4 | | | [4] Sepal.Width < 3.25: 4.735 | | | [5] Sepal.Width >= 3.25: 5.170 | | [6] Petal.Length >= 3.4: 5.640 | [7] Petal.Length >= 4.25 ,您可以捕获输出并进行编辑。

this.filters$.subscribe( res => {
  res.pipe(
    switchMap(filter => this.service.getItems(filter)
  )
})

我不确定您要插入什么标准偏差,但我希望您可以用同样的方式编辑它。

答案 1 :(得分:0)

print()对象的party方法非常灵活,可以通过各种面板功能和自定义进行控制。有关概述,请参见?print.party。但是,该文档有些简短和技术性。

在您的情况下,最简单的解决方案是设置响应y,大小权重w(在您的情况下默认为全1)和所需数量{{ 1}}:

digits

然后您可以将其传递给您的myfun <- function(y, w, digits = 2) { n <- sum(w) m <- weighted.mean(y, w) s <- sqrt(weighted.mean((y - m)^2, w) * n/(n - 1)) sprintf("%s (serr = %s)", round(m, digits = digits), round(s, digits = digits)) } 呼叫:

print()