来自rpart的混淆矩阵

时间:2014-01-22 04:04:00

标签: r machine-learning classification decision-tree confusion-matrix

我不能为我的生活弄清楚如何在rpart上计算混淆矩阵。

这就是我所做的:

set.seed(12345)
UBANK_rand <- UBank[order(runif(1000)), ]
UBank_train <- UBank_rand[1:900, ]
UBank_test  <- UBank_rand[901:1000, ]


dim(UBank_train)
dim(UBank_test)

#Build the formula for the Decision Tree
UB_tree <- Personal.Loan ~ Experience + Age+ Income +ZIP.Code + Family + CCAvg + Education

#Building the Decision Tree from Test Data
UB_rpart <- rpart(UB_tree, data=UBank_train)

现在,我认为我会做类似

的事情
table(predict(UB_rpart, UBank_test, UBank_Test$Default))

但这并没有给我一个混乱矩阵。

2 个答案:

答案 0 :(得分:11)

您没有提供可重现的示例,因此我将创建一个合成数据集:

set.seed(144)
df = data.frame(outcome = as.factor(sample(c(0, 1), 100, replace=T)),
                x = rnorm(100))

带有predict的{​​{1}}模型的rpart函数将返回每个观察的预测类。

type="class"

最后,您可以通过在预测和真实结果之间运行library(rpart) mod = rpart(outcome ~ x, data=df) pred = predict(mod, type="class") table(pred) # pred # 0 1 # 51 49 来构建混淆矩阵:

table

答案 1 :(得分:-1)

你可以尝试

pred <- predict(UB_rpart, UB_test) confusionMatrix(pred, UB_test$Personal.Loan)