我正在尝试使用rpart构建分类树模型。 测试数据框非常简单,只包含10行中的两个布尔变量。 隐藏的逻辑也很简单:当x为FALSE时,y必须为FALSE。当x为TRUE时,y有60%的可能性为TRUE。 所以我想象rpart会在x上进行一次拆分以提高节点纯度。但它保留在根节点,根本不分裂。有人请指教吗?
> df <- data.frame(x=rep(c(FALSE,TRUE), each=5), y=c(rep(FALSE,7), rep(TRUE,3)))
> df
x y
1 FALSE FALSE
2 FALSE FALSE
3 FALSE FALSE
4 FALSE FALSE
5 FALSE FALSE
6 TRUE FALSE
7 TRUE FALSE
8 TRUE TRUE
9 TRUE TRUE
10 TRUE TRUE
> rpart(y~x, method='class', data=df)
n= 10
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 10 3 FALSE (0.7000000 0.3000000) *
答案 0 :(得分:1)
正如我在评论中所说,这是为了避免过度拟合。形式上,有minsplit
参数,它被预设为20但可以调整以给出你想要的结果:
> library(rpart)
> df <- data.frame(x=rep(c(FALSE,TRUE), each=5), y=c(rep(FALSE,7), rep(TRUE,3)))
> rpart(y ~ x, data=df, minsplit=2)
n= 10
node), split, n, deviance, yval
* denotes terminal node
1) root 10 2.1 0.3
2) x< 0.5 5 0.0 0.0 *
3) x>=0.5 5 1.2 0.6 *
在
中找到更多关于avoice overfitting的论据(即cp
和maxdepth
)
help(rpart.control)
编辑:使用method =“class”,输出将更改为
> rpart(y ~ x, data=df, minsplit=2, method="class")
n= 10
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 10 3 FALSE (0.7000000 0.3000000)
2) x< 0.5 5 0 FALSE (1.0000000 0.0000000) *
3) x>=0.5 5 2 TRUE (0.4000000 0.6000000) *