我使用rpart
和ctree
在R中构建了决策树模型。
我还使用构建的模型预测了一个新的数据集,并得到了预测的概率和类。
但是,我想在单个字符串中提取规则/路径,以跟踪每个观察(在预测数据集中)。以表格格式存储这些数据,我可以自动解释预测,而无需打开R。
这意味着我想要关注。
ObsID Probability PredictedClass PathFollowed
1 0.68 Safe CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
2 0.76 Safe CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
3 0.88 Unsafe CarAge > 10 & Type = Van & Country = USA & Price > 15988
我正在寻找的代码是
library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)
这可能需要扩展到多行
predResults <- predict(fit, newdata = newcar, type= "GETPATTERNS")
答案 0 :(得分:7)
partykit
包有一个函数.list.rules.party()
,它当前未被导出,但可以用来做你想做的事情。我们还没有导出它的主要原因是它的输出类型可能会在未来的版本中发生变化。
要获得您在上面描述的预测,您可以这样做:
pathpred <- function(object, ...)
{
## coerce to "party" object if necessary
if(!inherits(object, "party")) object <- as.party(object)
## get standard predictions (response/prob) and collect in data frame
rval <- data.frame(response = predict(object, type = "response", ...))
rval$prob <- predict(object, type = "prob", ...)
## get rules for each node
rls <- partykit:::.list.rules.party(object)
## get predicted node and select corresponding rule
rval$rule <- rls[as.character(predict(object, type = "node", ...))]
return(rval)
}
使用iris
数据和rpart()
:
library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
## response prob.setosa prob.versicolor prob.virginica
## 1 setosa 1.00000000 0.00000000 0.00000000
## 51 versicolor 0.00000000 0.90740741 0.09259259
## 101 virginica 0.00000000 0.02173913 0.97826087
## rule
## 1 Petal.Length < 2.45
## 51 Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75
(此处仅显示每个物种的第一次观察以简洁。这对应于索引1,51和101.)
使用ctree()
:
ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
## response prob.setosa prob.versicolor prob.virginica
## 1 setosa 1.00000000 0.00000000 0.00000000
## 51 versicolor 0.00000000 0.97826087 0.02173913
## 101 virginica 0.00000000 0.02173913 0.97826087
## rule
## 1 Petal.Length <= 1.9
## 51 Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101 Petal.Length > 1.9 & Petal.Width > 1.7