Can someone explain this code from Luis Torgo (DMwR package):
cv.rpart <- function(form, train, test, ...) {
m <- rpartXse(form, train, ...)
p <- predict(m, test)
mse <- mean( (p-resp(form,test))^2 )
c( nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 ) )
}
cv.lm <- function(form, train, test, ...) {
m <- lm(form, train,...)
p <- predict(m, test)
p <- ifelse(p<0, 0, p)
mse <- mean( (p-resp(form,test))^2 )
c( nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 ) )
}
res <- experimentalComparison(c(dataset(a1 ~ .,clean.algae[,1:12],'a1')),
c(variants('cv.lm'), variants('cv.rpart',se=c(0,0.5,1))),
cvSettings(3,10,1234)
)
How does experimentalComparison use cv.rpart and cv.lm?
Answer 0 (score: 1)
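cv.lm and cv.rpart are the functions that experimentalComparison calls on every train/test split of the cross-validation. Each one fits a model (a linear model and a regression tree, respectively) on train, predicts on test, and returns the NMSE. For the tree, the call to experimentalComparison also specifies three values of the se pruning parameter (0, 0.5 and 1), so three tree variants are evaluated.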
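To make the calling pattern concrete, here is a rough base-R sketch of what a repeated k-fold loop around such a function could look like. It is not the DMwR implementation: run_cv, my_cv_lm and get_resp are hypothetical names (get_resp stands in for DMwR's resp), and the built-in mtcars data replaces clean.algae so the example runs without the package.

```r
# Hypothetical stand-in for DMwR's resp(): extracts the response values.
get_resp <- function(form, data) model.response(model.frame(form, data))

# Same shape as cv.lm in the question: fit on train, score on test, return NMSE.
my_cv_lm <- function(form, train, test, ...) {
  m <- lm(form, train, ...)
  p <- predict(m, test)
  mse <- mean((p - get_resp(form, test))^2)
  c(nmse = mse / mean((mean(get_resp(form, train)) - get_resp(form, test))^2))
}

# Illustrative loop only; the real experimentalComparison does more bookkeeping.
run_cv <- function(learner, form, data, reps = 3, folds = 10, seed = 1234) {
  set.seed(seed)
  scores <- c()
  for (r in seq_len(reps)) {
    # shuffle the fold labels so each row lands in one of `folds` groups
    fold_id <- sample(rep(seq_len(folds), length.out = nrow(data)))
    for (k in seq_len(folds)) {
      test  <- data[fold_id == k, , drop = FALSE]
      train <- data[fold_id != k, , drop = FALSE]
      # the learner sees only this split and returns one NMSE value
      scores <- c(scores, learner(form, train, test))
    }
  }
  scores
}

scores <- run_cv(my_cv_lm, mpg ~ wt + hp, mtcars)
length(scores)   # one value per repetition and fold
summary(scores)
```

With 3 repetitions of 10 folds the learner is called 30 times, which is why each model ends up with a distribution of NMSE values rather than a single number.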
If you run plot(res) at the end, as Torgo does in his code, you get boxplots of the errors of the 4 models (1 lm + 3 rpart).
I have commented the code below.
# this function combines training, cross-validation, pruning, prediction,
# and metric calculation
cv.rpart <- function(form, train, test, ...) {
# rpartXse grows a tree and estimates the cross-validation error
# of each candidate subtree. It then selects the best tree based on
# the results of this cross-validation.
# Torgo details how the optimal tree is chosen from the
# cross-validation results earlier in his code
m <- rpartXse(form, train, ...)
# use m to predict on test set
p <- predict(m, test)
# calculates normalized mean square error
# Refer https://rem.jrc.ec.europa.eu/RemWeb/atmes2/20b.htm
# for details on NMSE
mse <- mean( (p-resp(form,test))^2 )
c( nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 ) )
}
cv.lm <- function(form, train, test, ...) {
m <- lm(form, train,...)
p <- predict(m, test)
p <- ifelse(p<0, 0, p)
mse <- mean( (p-resp(form,test))^2 )
c( nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 ) )
}
# experimental comparison is designed to create numerous models
# based on parameters you provide it
# Arguments of experimentalComparison function are
# Dataset class object
# learner class object (contains the learning systems that will be used)
# settings class object
# These datatypes are unique to the DMwR package
# dataset is a function that creates a dataset object (a list)
# each element of the list contains the response variable
# and the actual data
res <- experimentalComparison(
c(dataset(a1 ~ .,clean.algae[,1:12],'a1')),
c(variants('cv.lm'),
# se specifies the number of standard errors to
# use in the post-pruning of the tree
variants('cv.rpart',se=c(0,0.5,1))),
# cvSettings specifies 3 repetitions of 10-fold
# cross-validation
# with a seed of 1234
cvSettings(3,10,1234)
)
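The se argument mentioned in the comments above can be illustrated with plain rpart. The sketch below is my reading of the "x standard errors" pruning rule, not the DMwR source: prune_xse is a hypothetical helper, the cp = 0 and minsplit = 6 settings are assumptions, and mtcars is used instead of clean.algae.

```r
library(rpart)

set.seed(1234)
# Grow a deliberately large tree; rpart stores cross-validation results in cptable.
full <- rpart(mpg ~ ., data = mtcars, cp = 0, minsplit = 6)

# Hypothetical helper: pick the smallest subtree whose cross-validated error
# is within `se` standard errors of the lowest cross-validated error.
prune_xse <- function(tree, se = 1) {
  tab <- tree$cptable
  best <- which.min(tab[, "xerror"])
  threshold <- tab[best, "xerror"] + se * tab[best, "xstd"]
  chosen <- min(which(tab[, "xerror"] <= threshold))  # rows are ordered small to large
  prune(tree, cp = tab[chosen, "CP"])
}

n_leaves <- function(tree) sum(tree$frame$var == "<leaf>")

t0 <- prune_xse(full, se = 0)   # tree with the lowest estimated error
t1 <- prune_xse(full, se = 1)   # more conservative, never larger than t0
c(se0 = n_leaves(t0), se1 = n_leaves(t1))
```

A larger se tolerates more estimated error in exchange for a smaller tree, so the three variants se = 0, 0.5 and 1 compare progressively more heavily pruned trees.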
summary(res) gives you basic statistics of the cross-validation results for each model.
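To read those statistics, it helps to work through the NMSE on a tiny example. The numbers below are made up for illustration; train_mean plays the role of mean(resp(form,train)) in the functions above.

```r
actual     <- c(3, 5, 7, 9)   # true test responses (made-up values)
model_pred <- c(4, 5, 6, 9)   # predictions of some model (made-up values)
train_mean <- 6               # assumed mean of the training response

# Error of the model on the test set: (1 + 0 + 1 + 0) / 4
mse_model <- mean((model_pred - actual)^2)

# Error of the baseline that always predicts the training mean: (9 + 1 + 1 + 9) / 4
mse_baseline <- mean((train_mean - actual)^2)

nmse <- mse_model / mse_baseline
nmse   # 0.1
```

An NMSE below 1 means the model beats the constant-mean baseline, and values near or above 1 mean it does no better than predicting the training average.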