mlr:使用验证集调整模型参数

时间:2018-08-04 17:05:05

标签: r hyperparameters mlr

仅在我的机器学习工作流程中切换为mlr。我想知道是否有可能使用单独的验证集来调整超参数。据我的最低理解,makeResampleDescmakeResampleInstance仅接受来自训练数据的重采样。

我的目标是使用验证集调整参数,并使用测试集测试最终模型。这是为了防止过度拟合和知识泄漏。

这是我在代码方面所做的:

## Create training, validation and test tasks
train_task <- makeClassifTask(data = train_data, target = "y", positive = 1)
validation_task <- makeClassifTask(data = validation_data, target = "y")
test_task <- makeClassifTask(data = test_data, target = "y")

## Attempt to tune parameters with separate validation data
tuned_params <- tuneParams(
    task = train_task,
    resampling = makeResampleInstance("Holdout", task = validation_task),
    ...
)

从错误消息中,看来评估仍在尝试从训练集中重新采样:

  

00001:resample.fun中的错误(学习者2,任务,重采样,度量=   度量,:数据集的大小:19454和重采样实例:   1666333不一样!

有人知道我应该怎么做吗?我是否以正确的方式设置了所有内容?

1 个答案:

答案 0 :(得分:0)

[自2019/03/27起更新]

遵循@ jakob-r的评论,最后了解@LarsKotthoff的建议,这就是我所做的:

## Create combined training data
train_task_data <- rbind(train_data, validation_data)

## Create learner, training task, etc.
xgb_learner <- makeLearner("classif.xgboost", predict.type = "prob")
train_task <- makeClassifTask(data = train_task_data, target = "y", positive = 1)

## Tune hyperparameters
tune_wrapper <- makeTuneWrapper(
  learner = xgb_learner,
  resampling = makeResampleDesc("Holdout"),
  measures = ...,
  par.set = ...,
  control = ...
)
model_xgb <- train(tune_wrapper, train_task)

这是我在@LarsKotthoff的评论之后所做的事情。假设您有两个分别用于训练(train_data)和验证(validation_data)的数据集:

## Create combined training data
train_task_data <- rbind(train_data, validation_data)
size <- nrow(train_task_data)
train_ind <- seq_len(nrow(train_data))
validation_ind <- seq.int(max(train_ind) + 1, size)

## Create training task
train_task <- makeClassifTask(data = train_task_data, target = "y", positive = 1)

## Tune hyperparameters
tuned_params <- tuneParams(
    task = train_task,
    resampling = makeFixedHoldoutInstance(train_ind, validation_ind, size),
    ...
)

优化超参数集后,您可以构建最终模型并针对测试数据集进行测试。

注意:我必须从GitHub安装最新的开发版本(截至2018/08/06)。当前的CRAN版本(2.12.1)在我调用makeFixedHoldoutInstance()时抛出错误,即

  

对“ discrete.names”的声明失败:必须为“逻辑标志”类型,   不是“ NULL”。