从cva.glmnet对象中提取最佳参数

时间:2019-02-21 09:48:23

标签: r cross-validation glmnet

我敢肯定,运行cva.glmnet后,有一种优雅的方法可以提取最佳的alpha和lambda,但是我无法找到它。

这是我在此期间使用的代码。

谢谢

library(data.table);library(glmnetUtils);library(useful)

# make some dummy data

data(iris)

x <- useful::build.x(data = iris,formula = Sepal.Length ~ .)
y <- iris$Sepal.Length

# run cv for alpha in c(0,0.5,1)

output.of.cva.glmnet <- cva.glmnet(x=x,y=y,alpha = c(0,0.5,1))

# extract the best parameters

number.of.alphas.tested <- length(output.of.cva.glmnet$alpha)

cv.glmnet.dt <- data.table()

for (i in 1:number.of.alphas.tested){
  glmnet.model <- output.of.cva.glmnet$modlist[[i]]
  min.mse <-  min(glmnet.model$cvm)
  min.lambda <- glmnet.model$lambda.min
  alpha.value <- output.of.cva.glmnet$alpha[i]
  new.cv.glmnet.dt <- data.table(alpha=alpha.value,min_mse=min.mse,min_lambda=min.lambda)
  cv.glmnet.dt <- rbind(cv.glmnet.dt,new.cv.glmnet.dt)
}

best.params <- cv.glmnet.dt[which.min(cv.glmnet.dt$min_mse)]

enter image description here

1 个答案:

答案 0 :(得分:3)

基于我在GitHub上阅读的一个线程,作者希望人们使用plot(fit)而不是仅输出最佳参数。但是,这并非总是可能的,尤其是在涉及交叉验证时。这些帮助程序功能可能是一个不错的解决方法。

# Train model.
fit <- cva.glmnet(X, y)

# Get alpha.
get_alpha <- function(fit) {
  alpha <- fit$alpha
  error <- sapply(fit$modlist, function(mod) {min(mod$cvm)})
  alpha[which.min(error)]
}

# Get all parameters.
get_model_params <- function(fit) {
  alpha <- fit$alpha
  lambdaMin <- sapply(fit$modlist, `[[`, "lambda.min")
  lambdaSE <- sapply(fit$modlist, `[[`, "lambda.1se")
  error <- sapply(fit$modlist, function(mod) {min(mod$cvm)})
  best <- which.min(error)
  data.frame(alpha = alpha[best], lambdaMin = lambdaMin[best],
             lambdaSE = lambdaSE[best], eror = error[best])
}