我为决策树绘制了一张偏差（deviance）随分割数（splits）变化的图，代码如下。
data(mtcars)
cars <- mtcars[, 1:4]
smp_size <- floor(0.75 * nrow(cars))
set.seed(100)
# Sample from `cars` itself; the original referenced an undefined `LoanData`,
# so the script stopped with "object 'LoanData' not found".
train_ind <- sample(seq_len(nrow(cars)), size = smp_size)
train <- cars[train_ind, ]
test <- cars[-train_ind, ]
# Decision tree model, grown on the training rows only so that `test` stays
# unseen and can later be used for an out-of-sample deviance curve.
library(tree)
car_tree <- tree(mpg ~ ., data = train, mindev = 0.003, mincut = 2,
                 minsize = 6)
# 10-fold cross-validated deviance along the cost-complexity pruning sequence.
cv_tree <- cv.tree(car_tree, FUN = prune.tree, K = 10)
# cv.tree reports `size` as the number of terminal nodes; a binary tree with
# `size` leaves has `size - 1` splits, which is what the x-axis label promises.
plot(cv_tree$size - 1, cv_tree$dev, type = "b", xlab = "splits",
     ylab = "deviance", main = "deviance by splits")
我想在此图上叠加测试数据的偏差，以便观察从多少次分割之后测试偏差开始再次上升。请问该怎么做？
答案 0（得分：0）
要查看树的不同深度的准确性,需要修剪树,预测训练和测试结果,并评估训练和测试结果的准确性。
这是数据,提取了训练和测试子集
data(mtcars)
cars <- mtcars
# Put 75% of the rows (rounded down) into the training set, the rest into test
smp_size <- floor(nrow(cars) * 0.75)
set.seed(100)
# sample.int(n, k) draws the same indices as sample(seq_len(n), size = k)
train_ind <- sample.int(nrow(cars), smp_size)
train <- cars[train_ind, ]
test <- cars[-train_ind, ]
下面是一个辅助函数，用于计算给定模型在训练集和测试集上的准确度。您可以修改此函数以加入其他验证指标。
# Compare in-sample and out-of-sample accuracy of a fitted model.
#
# Arguments:
#   tr        fitted model whose predict() method accepts new data as its
#             second positional argument (e.g. a `tree` or `lm` fit)
#   train     training data frame
#   test      held-out data frame
#   dpth      depth label stored in the `depth` column of the result
#   rst       optional data frame of earlier results to append to
#   response  name of the response column (default "mpg", as before)
#
# Returns `rst` with one extra row holding the correlation between observed
# and predicted response and the standard deviation of the residuals, for
# both the training and the test data.
compare <- function(tr, train, test, dpth, rst = NULL, response = "mpg") {
  stopifnot(response %in% names(train), response %in% names(test))

  # Correlation and residual SD of the model's predictions on one data set.
  # `[[` is used instead of `$` to avoid data.frame partial name matching.
  score <- function(data) {
    observed <- data[[response]]
    predicted <- predict(tr, data)
    c(cor = cor(observed, predicted), sd = sd(predicted - observed))
  }

  on_train <- score(train)
  on_test <- score(test)
  row <- data.frame(
    cor.train = on_train[["cor"]],
    cor.test = on_test[["cor"]],
    sd.train = on_train[["sd"]],
    sd.test = on_test[["sd"]],
    depth = dpth
  )
  # rbind(NULL, row) is just `row`, so the first call starts a new table
  rbind(rst, row)
}
创建树
# Decision tree model, fitted on the training rows only
library(tree)
car_tree <- tree(
  mpg ~ .,
  data = train,
  mindev = 0.003,
  mincut = 2,
  minsize = 6
)
打印树,确定深度(3)和最深的分裂节点(4:7)
print(car_tree)
# In the answer's output the tree has depth 3; its deepest split nodes are 4:7
获得深度为3的结果
rslts <- compare(car_tree, train, test, dpth = 3)  # full tree, depth 3
现在在节点 4:7 处剪枝，然后打印结果。注意剪枝后深度为 2，最深的分裂节点是 2:3。
car_tree_sn_1 <- snip.tree(car_tree, 4:7)
print(car_tree_sn_1)
# After snipping nodes 4:7 the depth is 2 and the deepest split nodes are 2:3
获得深度为2的结果
rslts <- compare(car_tree_sn_1, train, test, dpth = 2, rst = rslts)  # depth 2
现在在节点 2:3 处剪枝，然后打印结果。注意剪枝后深度为 1，除根节点外不再有分裂节点。
car_tree_sn_2 <- snip.tree(car_tree, 2:3)
print(car_tree_sn_2)
# After snipping nodes 2:3 the depth is 1; no split nodes remain below the root
获得深度为1的结果
rslts <- compare(car_tree_sn_2, train, test, dpth = 1, rst = rslts)  # depth 1
绘制准确度估算值
# Training (black) vs. test (red) accuracy as a function of tree depth.
# Y-limits are taken from the data so no point is clipped.

# Correlation: plotted on a linear axis, because a log axis fails if any
# correlation is <= 0.
cor_range <- range(rslts$cor.train, rslts$cor.test, na.rm = TRUE)
plot(rslts$depth, rslts$cor.train, type = "b", xlab = "depth",
     ylab = "Correlation Coefficient", main = "Correlation by depth",
     ylim = cor_range)
lines(rslts$depth, rslts$cor.test, type = "b", col = "red")
legend("bottomright", legend = c("train", "test"), col = c("black", "red"),
       lty = 1, pch = 1)

# Residual standard deviation (always positive, so a log axis is safe).
sd_range <- range(rslts$sd.train, rslts$sd.test, na.rm = TRUE)
plot(rslts$depth, rslts$sd.train, type = "b", xlab = "depth",
     ylab = "Standard Deviation", main = "Standard deviation by depth",
     log = "y", ylim = sd_range)
lines(rslts$depth, rslts$sd.test, type = "b", col = "red")
legend("topright", legend = c("train", "test"), col = c("black", "red"),
       lty = 1, pch = 1)
也可以使用其他树模型包。下面是一个 rpart 的例子。
# Regression Tree Example (rpart)
data(mtcars)
cars <- mtcars[, 1:6]
library(rpart)
# Grow the tree.
# NOTE(review): `control` is a plain list here, so presumably only `minsplit`
# is overridden and `minbucket` keeps its default; use
# rpart.control(minsplit = 1) if very small leaves are intended -- confirm.
fit <- rpart(mpg ~ ., data = cars, control = list(minsplit = 1))
printcp(fit)   # display the complexity-parameter table
plotcp(fit)    # visualize cross-validation results
summary(fit)   # detailed summary of splits
# Two plots per page for the following figures; keep the previous layout so
# it can be restored at the end instead of leaking into later plots.
old_par <- par(mfrow = c(1, 2))
rsq.rpart(fit) # visualize cross-validation results
# Plot the tree
library(rpart.plot)
prp(fit, extra = 101, branch.type = 3)
plot(fit, uniform = TRUE,
     main = "Regression Tree for Mileage ")
text(fit, use.n = TRUE, all = TRUE, cex = .8)
par(old_par)