获取R中rpart / ctree包的每行预测数据集的决策树规则/路径模式

时间:2015-04-14 03:08:43

标签: r decision-tree rpart

我使用rpartctree在R中构建了决策树模型。 我还使用构建的模型预测了一个新的数据集,并得到了预测的概率和类。

但是,我想在单个字符串中提取规则/路径,以跟踪每个观察(在预测数据集中)。以表格格式存储这些数据,我可以自动解释预测,而无需打开R。

这意味着我想要关注。

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

我正在寻找的代码是

library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)

这可能需要扩展到多行

predResults <- predict(fit, newdata = newcar, type= "GETPATTERNS")

1 个答案:

答案 0 :(得分:7)

partykit包有一个函数.list.rules.party(),它当前未被导出,但可以用来做你想做的事情。我们还没有导出它的主要原因是它的输出类型可能会在未来的版本中发生变化。

要获得您在上面描述的预测,您可以这样做:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}

使用iris数据和rpart()

的插图
library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(此处仅显示每个物种的第一次观察以简洁。这对应于索引1,51和101.)

使用ctree()

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7