SageMaker TF 2.3 分布式训练

时间:2021-03-16 13:24:10

标签: python tensorflow amazon-sagemaker

使用 SageMaker v2.29.2 和 Tensorflow v2.3.2,我正在尝试实施分布式训练,如以下博文所述:

https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-customize-training-script-tf.html#model-parallel-customize-training-script-tf-23

但是我在导入 smdistributed 脚本时遇到了困难。

这是我的代码:

import tensorflow as tf
import smdistributed.modelparallel.tensorflow as smp

错误:

Traceback (most recent call last):
  File "temp.py", line 2, in <module>
    import smdistributed.modelparallel.tensorflow as smp
ModuleNotFoundError: No module named 'smdistributed'

我错过了什么?

1 个答案:

答案 0 :(得分:0)

smdistributed 仅适用于 SageMaker 容器。它支持特定的 TensorFlow 版本,您必须添加:

distribution={'smdistributed': {
            'dataparallel': {
                'enabled': True
            }
        }}

在估算器代码上启用它