如何在R中绘制学习曲线?

时间:2016-07-26 06:01:31

标签: regression linear-regression

我想在我的应用程序中绘制学习曲线。

示例曲线图像如下所示。

enter image description here

学习曲线是以下方差之间的曲线,

  • X轴:样本数(训练集大小)。
  • Y轴:错误(RSS / J(theta)/成本函数)

它有助于观察我们的模型是否具有高偏差或高偏差问题。

R中是否有任何包可以帮助获得这个情节?

1 个答案:

答案 0 :(得分:2)

你可以使用优秀的Caret包制作这样的情节。 Customizing the tuning process部分将非常有用。

此外,您可以查看Joseph Rickert撰写的关于R-Bloggers的精彩博客文章。它们的标题为"Why Big Data? Learning Curves""Learning from Learning Curves"

<强>更新
我刚刚在这个问题Plot learning curves with caret package and R上发了一篇文章。我想我的回答对你更有用。为了方便起见,我在这里用R绘制学习曲线重现了相同的答案。但是,我使用了流行的caret包来训练我的模型并获得训练和测试集的RMSE错误。

# set seed for reproducibility
set.seed(7)

# randomize mtcars
mtcars <- mtcars[sample(nrow(mtcars)),]

# split iris data into training and test sets
mtcarsIndex <- createDataPartition(mtcars$mpg, p = .625, list = F)
mtcarsTrain <- mtcars[mtcarsIndex,]
mtcarsTest <- mtcars[-mtcarsIndex,]

# create empty data frame 
learnCurve <- data.frame(m = integer(21),
                     trainRMSE = integer(21),
                     cvRMSE = integer(21))

# test data response feature
testY <- mtcarsTest$mpg

# Run algorithms using 10-fold cross validation with 3 repeats
trainControl <- trainControl(method="repeatedcv", number=10, repeats=3)
metric <- "RMSE"

# loop over training examples
for (i in 3:21) {
    learnCurve$m[i] <- i

    # train learning algorithm with size i
    fit.lm <- train(mpg~., data=mtcarsTrain[1:i,], method="lm", metric=metric,
             preProc=c("center", "scale"), trControl=trainControl)        
    learnCurve$trainRMSE[i] <- fit.lm$results$RMSE

    # use trained parameters to predict on test data
    prediction <- predict(fit.lm, newdata = mtcarsTest[,-1])
    rmse <- postResample(prediction, testY)
    learnCurve$cvRMSE[i] <- rmse[1]
}

pdf("LinearRegressionLearningCurve.pdf", width = 7, height = 7, pointsize=12)

# plot learning curves of training set size vs. error measure
# for training set and test set
plot(log(learnCurve$trainRMSE),type = "o",col = "red", xlab = "Training set size",
          ylab = "Error (RMSE)", main = "Linear Model Learning Curve")
lines(log(learnCurve$cvRMSE), type = "o", col = "blue")
legend('topright', c("Train error", "Test error"), lty = c(1,1), lwd = c(2.5, 2.5),
       col = c("red", "blue"))

dev.off()

输出图如下所示:
MtCarsLearningCurve.png