R H2O网格搜索:如何训练新数据的顶级模型?

时间:2018-01-07 03:09:15

标签: r h2o

运行超参数搜索并从网格中提取最佳模型后,是否可以使用模型对象训练新数据集?我现在看到的唯一方法是使用最佳模型的参数手动创建对火车功能的调用(例如h2o.gbm()),但这非常麻烦。

1 个答案:

答案 0 :(得分:1)

checkpoint参数可以满足您的需求,从原始模型进一步训练模型。

此功能适用于gbm包中的random forestdeep learningh2o

以下示例代码从http://s3.amazonaws.com/h2o-release/h2o/master/3689/docs-website/h2o-docs/data-science/algo-params/checkpoint.html

复制
library(h2o)
h2o.init()

# import the cars dataset:
# this dataset is used to classify whether or not a car is economical based on
# the car's displacement, power, weight, and acceleration, and the year it was made
cars <- h2o.importFile("https://s3.amazonaws.com/h2o-public-test-data/smalldata/junit/cars_20mpg.csv")

# convert response column to a factor
cars["economy_20mpg"] <- as.factor(cars["economy_20mpg"])

# set the predictor names and the response column name
predictors <- c("displacement","power","weight","acceleration","year")
response <- "economy_20mpg"

# split into train and validation sets
cars.split <- h2o.splitFrame(data = cars,ratios = 0.8, seed = 1234)
train <- cars.split[[1]]
valid <- cars.split[[2]]

# build a GBM with 1 tree (ntrees = 1) for the first model:
cars_gbm <- h2o.gbm(x = predictors, y = response, training_frame = train,
                    validation_frame = valid, ntrees = 1, seed = 1234)

# print the auc for the validation data
print(h2o.auc(cars_gbm, valid = TRUE))

# re-start the training process on a saved GBM model using the ‘checkpoint‘ argument:
# the checkpoint argument requires the model id of the model on which you wish to continue building
# get the model's id from "cars_gbm" model using `cars_gbm@model_id`
# the first model has 1 tree, let's continue building the GBM with an additional 49 more trees, so set ntrees = 50

# to see how many trees the original model built you can look at the `ntrees` attribute
print(paste("Number of trees built for cars_gbm model:", cars_gbm@allparameters$ntrees))

# build and train model with 49 additional trees for a total of 50 trees:
cars_gbm_continued <- h2o.gbm(x = predictors, y = response, training_frame = train,
                    validation_frame = valid, checkpoint = cars_gbm@model_id, ntrees = 50, seed = 1234)

# print the auc for the validation data
print(h2o.auc(cars_gbm_continued, valid = TRUE))

# you can also use checkpointing to pass in a new dataset (see options above for parameters you cannot change)
# simply change out the training and validation frames with your new dataset

编辑(根据@ Edward的评论如下:)

h2o.grid将返回一系列模型,您可以获得最好的模型汉德尔。所有参数都保存在模型汉德尔中,然后您可以将参数应用于新模型。

grid <- h2o.getGrid(h2o.grid@grid_id,sort_by = "auc",decreasing=TRUE)
model.h2o <- h2o.getModel(grid@model_ids[[1]])

model@allparameters包含所有使用的参数,您可以使用这些参数创建新模型和新数据。