Experimental comparison and variants

Date: 2015-04-17 08:31:46

Tags: r

Can someone explain this code from Luis Torgo (DMwR package)?

library(DMwR)  # provides rpartXse(), resp(), experimentalComparison(), etc.

cv.rpart <- function(form, train, test, ...) {
  m   <- rpartXse(form, train, ...)
  p   <- predict(m, test)
  mse <- mean( (p-resp(form,test))^2 )
  c(  nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 )  )
}
cv.lm <- function(form, train, test, ...) {
  m   <- lm(form, train,...)
  p   <- predict(m, test)
  p   <- ifelse(p<0, 0, p)
  mse <- mean( (p-resp(form,test))^2 )
  c(  nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 )  )
}

# clean.algae: the algae data after missing-value handling, created earlier in Torgo's code
res <- experimentalComparison(c(dataset(a1 ~ .,clean.algae[,1:12],'a1')),
                              c(variants('cv.lm'), variants('cv.rpart',se=c(0,0.5,1))),
                              cvSettings(3,10,1234)
                              )

How does experimentalComparison use cv.rpart and cv.lm?

1 answer:

Answer (score: 1)

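cv.lm and cv.rpart are the learner functions for a linear model and a regression tree, respectively. Each one fits its model on a training fold and returns the NMSE on the matching test fold. experimentalComparison does the splitting: it calls them once per fold of every repetition, passing the formula, the training part, the test part and any variant parameters (which reach rpartXse or lm through ...). For the tree, the experimentalComparison call also specifies three values of the pruning parameter se, which gives three tree variants.

If you run plot(res) at the end, as Torgo does in his code, you get box plots of the errors of the 4 models (1 lm + 3 rpart variants).

I've commented the code below.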
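To make this calling convention concrete, here is a minimal base-R sketch of a repeated k-fold loop that calls a learner with the signature learner(form, train, test) once per fold. It is not DMwR's implementation: run_cv, response_of and cv_lm_sketch are hypothetical names, response_of is a base-R stand-in for DMwR's resp(), and the built-in mtcars data replaces clean.algae so the snippet runs without DMwR.

```r
# Hypothetical sketch (not DMwR's code): repeated k-fold cross-validation that
# calls learner(form, train, test) on every fold and collects the scores.
run_cv <- function(learner, form, data, n_reps = 3, n_folds = 10, seed = 1234) {
  set.seed(seed)
  scores <- numeric(0)
  for (r in seq_len(n_reps)) {
    # randomly assign each row to one of n_folds folds
    fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(data)))
    for (f in seq_len(n_folds)) {
      train <- data[fold_id != f, , drop = FALSE]
      test  <- data[fold_id == f, , drop = FALSE]
      scores <- c(scores, learner(form, train, test))
    }
  }
  scores
}

# Base-R stand-in for DMwR's resp(): the response values of `data` under `form`
response_of <- function(form, data) model.response(model.frame(form, data))

# A cv.lm-style learner written with base R only
cv_lm_sketch <- function(form, train, test, ...) {
  m   <- lm(form, train, ...)
  p   <- predict(m, test)
  y   <- response_of(form, test)
  mse <- mean((p - y)^2)
  c(nmse = mse / mean((mean(response_of(form, train)) - y)^2))
}

scores <- run_cv(cv_lm_sketch, mpg ~ wt + hp, mtcars, n_reps = 3, n_folds = 4)
length(scores)  # 12: one NMSE per fold per repetition
```

In the same way, cvSettings(3,10,1234) leads to 3 x 10 = 30 NMSE values for each learner variant.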
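The values in those box plots are NMSE scores. Because the denominator is the error of always predicting the training mean, that baseline scores exactly 1 and a perfect model scores 0, so lower is better and a value above 1 means the model does worse than the baseline. A small worked example with made-up numbers; nmse() is a hypothetical helper that repeats the formula used in cv.lm and cv.rpart:

```r
# NMSE as in cv.lm / cv.rpart: the model's MSE divided by the MSE of a
# baseline that always predicts the mean of the training response.
nmse <- function(pred, y_test, y_train) {
  mean((pred - y_test)^2) / mean((mean(y_train) - y_test)^2)
}

y_train <- c(2, 4, 6, 8)  # training mean = 5
y_test  <- c(3, 7)

nmse(rep(mean(y_train), length(y_test)), y_test, y_train)  # baseline: 1
nmse(y_test, y_test, y_train)                               # perfect: 0
nmse(c(4, 6), y_test, y_train)                              # in between: 0.25
```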

# For one train/test split, this function grows a tree, prunes it using
# cross-validation on the training data, predicts on the test data and
# computes the error metric
cv.rpart <- function(form, train, test, ...) {
  # rpartXse grows a tree with rpart, which records the cross-validated
  # error of each subtree in the pruning sequence. It then keeps the
  # smallest subtree whose error is within se standard errors of the
  # minimum (se arrives through ...).
  # Torgo details how the optimal tree based on
  # cross-validation results is chosen
  # earlier in his code
  m   <- rpartXse(form, train, ...)
  # use m to predict on the test set
  p   <- predict(m, test)
  # normalized mean squared error: the model's MSE divided by the MSE of
  # always predicting the mean of the training response.
  # See https://rem.jrc.ec.europa.eu/RemWeb/atmes2/20b.htm for more on NMSE
  # (definitions vary between fields)
  mse <- mean( (p-resp(form,test))^2 )
  c(  nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 )  )
}


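# Same idea for a linear regression model
cv.lm <- function(form, train, test, ...) {
  # fit the linear model on the training fold
  m   <- lm(form, train,...)
  # predict on the test fold
  p   <- predict(m, test)
  # algae frequencies cannot be negative, so negative predictions become 0
  p   <- ifelse(p<0, 0, p)
  # same NMSE as in cv.rpart
  mse <- mean( (p-resp(form,test))^2 )
  c(  nmse=mse/mean( (mean(resp(form,train))-resp(form,test))^2 )  )
}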

# experimentalComparison fits and evaluates every learner variant on every
# dataset, using the estimation settings you give it

# Its arguments are:
#   a list of dataset objects
#   a list of learner objects (the learning systems to compare)
#   a settings object (here, cross-validation settings)
# These classes are specific to the DMwR package.
# dataset() creates a dataset object holding the formula (which identifies
# the response variable), the data and a name ('a1' here); c() collects
# the dataset objects into a list
res <- experimentalComparison(
         c(dataset(a1 ~ ., clean.algae[, 1:12], 'a1')),
         c(variants('cv.lm'),
           # se specifies the number of standard errors to use
           # in the post-pruning of the tree, giving 3 tree variants
           variants('cv.rpart', se = c(0, 0.5, 1))),
         # cvSettings specifies 3 repetitions of 10-fold
         # cross-validation with random seed 1234
         cvSettings(3, 10, 1234)
       )
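The three rpart variants differ only in se. Assuming rpartXse applies the X-SE rule (keep the smallest subtree whose cross-validated error is within se standard errors of the minimum), the sketch below shows how a larger se selects a smaller tree. The table mimics the xerror and xstd columns of rpart's cptable, but its numbers are invented; the variant names and xse_choice() are hypothetical.

```r
# Hypothetical sketch of variants('cv.rpart', se = c(0, 0.5, 1)):
# one learner per parameter value (names are illustrative only).
grid <- expand.grid(se = c(0, 0.5, 1))
variant_names <- paste0("cv.rpart.v", seq_len(nrow(grid)))

# Invented cross-validation table shaped like rpart's cptable
# (rows go from the smallest tree to the largest).
cp_table <- data.frame(
  nsplit = c(0, 1, 2, 4, 6),
  xerror = c(1.00, 0.72, 0.68, 0.65, 0.67),
  xstd   = c(0.10, 0.09, 0.08, 0.08, 0.09)
)

# X-SE rule: number of splits of the smallest tree whose cross-validated
# error is at most min(xerror) + se * xstd at that minimum.
xse_choice <- function(tab, se) {
  best  <- which.min(tab$xerror)
  limit <- tab$xerror[best] + se * tab$xstd[best]
  tab$nsplit[which(tab$xerror <= limit)[1]]
}

chosen <- sapply(grid$se, function(s) xse_choice(cp_table, s))
setNames(chosen, variant_names)  # cv.rpart.v1 = 4, cv.rpart.v2 = 2, cv.rpart.v3 = 1
```

With these numbers, se = 0 keeps the minimum-error tree (4 splits), while se = 1 accepts a tree with a single split whose error is within one standard error of that minimum. Together with cv.lm, that makes the 4 models in the comparison.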

summary(res) gives you basic statistics of the cross-validation results for each model.
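As a rough illustration of such per-model statistics, the sketch below computes the average, standard deviation, minimum and maximum NMSE over the test folds for two linear models on mtcars. It uses base R only, the model names and the nmse() helper are hypothetical, and it does not reproduce the exact layout of DMwR's summary output.

```r
# Hypothetical sketch: per-model statistics over cross-validation scores,
# similar in spirit to summary(res) (not DMwR's exact output).
nmse <- function(pred, y_test, y_train) {
  mean((pred - y_test)^2) / mean((mean(y_train) - y_test)^2)
}

# Two illustrative "models": linear regressions with different formulas
forms <- list(lm.small = mpg ~ wt, lm.large = mpg ~ wt + hp + qsec)

set.seed(1234)
n_folds <- 4
fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(mtcars)))

# scores[f, m]: NMSE of model m on test fold f
scores <- sapply(forms, function(form) {
  sapply(seq_len(n_folds), function(f) {
    train <- mtcars[fold_id != f, ]
    test  <- mtcars[fold_id == f, ]
    p <- predict(lm(form, train), test)
    nmse(p, test$mpg, train$mpg)
  })
})

# one column per model, one row per statistic
stats <- apply(scores, 2, function(s) c(avg = mean(s), std = sd(s), min = min(s), max = max(s)))
round(stats, 3)
```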