使用Scikit Learn创建Amazon SageMaker超参数调整作业

时间:2018-07-17 21:42:00

标签: scikit-learn amazon-sagemaker

我想知道如何使用Amazon Sagemaker自动调整scikit学习随机森林模型。现在,我想调整一个称为“ max_depth”的超参数。我将先转储我的代码,然后再表达一些担忧。

文件:notebook.ipynb

estimator = sagemaker.estimator.Estimator(image, role,
              train_instance_count=1,
              train_instance_type='ml.m4.xlarge',
              output_path=output_location,
              sagemaker_session=sagemaker_session,
              )

hyperparameter_ranges = {'max_depth': IntegerParameter(20, 30)}
objective_metric_name = 'score'
metric_definitions = [{'Name': 'score', 'Regex': 'score: ([0-9\\.]+)'}]

tuner = HyperparameterTuner(estimator,
                        objective_metric_name,
                        hyperparameter_ranges,
                        metric_definitions,
                        max_jobs=9,
                        max_parallel_jobs=3)
tuner.fit({'train': train_data_location, 'test': test_data_location})

文件:train(位于Docker容器中)

def train():
    with open(param_path, 'r') as tc:
        hyperparams = json.load(tc)
    print("DEBUG VALUE: ", hyperparams)
    data, class = get_data() #abstraction
    X, y = train_data.drop(['class'], axis=1), train_data['class']
    clf = RandomForestClassifier()
    clf.fit(data, class)
    print("score: " + str(evaluate_model(clf)) + "\n")

我看到此代码有两个问题。首先,如果我在必要的路径下的名为hyperparameters.json的文件中放置了一个json对象{'max_value':2},则print语句将输出{},就好像该文件为空。

第2个问题是train()不允许超参数以任何形式或形式影响代码。据我所知,亚马逊没有关于tuner.fit()方法内部工作的文档。这意味着我无法弄清楚train()如何访问超参数进行测试。

我们非常感谢您的帮助,请告知我是否可以提供更多代码或澄清任何内容。

1 个答案:

答案 0 :(得分:0)

对于问题1,该服务将在/opt/ml/input/config/hyperparameters.json上为您编写hyperparameters.json -您无需自己编写。

对于问题2,tuner.fit(),fit()函数将启动SageMaker Tuning作业,该作业将运行多个SageMaker Training作业-每个作业都将调用您的train()函数。每个训练作业将获得不同的超参数集,因此您的train()函数的职责是简单地读取文件并使用其中的值来参数化算法。然后,您的算法应针对接收到的一组超参数发出特定的客观指标值和模型。

“调整”作业将查看每个成功完成的训练作业的客观指标,并从超参数搜索空间中找出最佳超参数以运行下一个作业或停止调整并发出找到的最佳模型/超参数,因此远。

我们已经开放了SageMaker Python SDK的源代码,因此您可以查看代码以了解更多详细信息:https://github.com/aws/sagemaker-python-sdk/blob/2fa160c44d92c9eacdb8f79265676eca42832233/src/sagemaker/tuner.py#L225

您可以在https://github.com/aws/sagemaker-python-sdk#sagemaker-automatic-model-tuning

上了解更多有关从Python SDK启动调优作业的信息。

希望这会有所帮助!