在Amazon Sagemaker中对自定义代码进行增量培训

时间:2019-11-28 11:47:23

标签: python tensorflow amazon-sagemaker

我正在迈向amazon sagemaker的第一步。我正在使用脚本模式来训练分类算法。培训还不错,但是我无法进行增量培训。我想用新数据再次训练相同的模型。这是我做的。这是我的脚本:

import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker import get_execution_role

bucket = 'sagemaker-blablabla'
train_data = 's3://{}/{}'.format(bucket,'train')
validation_data = 's3://{}/{}'.format(bucket,'test')

s3_output_location = 's3://{}'.format(bucket)

tf_estimator = TensorFlow(entry_point='main.py', 
                          role=get_execution_role(),
                          train_instance_count=1, 
                          train_instance_type='ml.p2.xlarge',
                          framework_version='1.12', 
                          py_version='py3',
                          output_path=s3_output_location)

inputs = {'train': train_data, 'test': validation_data}
tf_estimator.fit(inputs)

入口点是我的自定义keras代码,我将该代码适配为从脚本中接收参数。 现在,培训已成功完成,并且在我的s3存储桶中有model.tar.gz。我想再次训练,但是我不清楚该怎么做。我尝试过

trained_model = 's3://sagemaker-blablabla/sagemaker-tensorflow-scriptmode-2019-11-27-12-01-42-300/output/model.tar.gz'

tf_estimator = sagemaker.estimator.Estimator(image_name='blablabla-west-1.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-gpu-py3', 
                                              role=get_execution_role(),
                                              train_instance_count=1, 
                                              train_instance_type='ml.p2.xlarge',
                                              output_path=s3_output_location,
                                              model_uri = trained_model)

inputs = {'train': train_data, 'test': validation_data}

tf_estimator.fit(inputs)

不起作用。首先,我不知道如何检索训练图像的名称(为此,我在aws控制台中寻找了该图像,但是我想应该有一个更聪明的解决方案),其次,此代码引发关于条目的异常要点,但是据我了解,当我用就绪图像进行增量学习时,我不需要它。 我肯定错过了重要的东西,有什么帮助吗?谢谢!

1 个答案:

答案 0 :(得分:0)

增量训练是内置Image Classifier and Object Detector的一项固有功能。对于自定义代码,开发人员有责任编写增量训练逻辑并验证其有效性。这是一条可能的路径:

  1. 使用fit中传递的数据通道之一来加载模型状态(需对工件进行微调)
  2. 在您的代码中,检查是否填充了模型状态通道 与文物。如果是,请从该状态实例化模型 并继续训练。这是特定于框架的,您可以采取 必要的预防措施,以避免忘记以前的经验。

某些框架为增量学习提供了更好的支持。例如,某些sklearn模型提供了一种incremental_fit方法。对于DL框架,从检查点继续进行培训在技术上非常容易,但是如果新数据与以前看到的数据有很大不同,则这可能会使您的模型忘记先前的学习。