如何将device_filters与tf.contrib.learn.Experiment一起使用?

时间:2017-10-23 21:13:43

标签: tensorflow google-cloud-ml-engine

默认情况下,TensorFlow分布式培训建立了工作者和参数服务器之间的所有连接,即使在异步分布式培训中,每个工作者和参数服务器之间唯一必要的通信。

当我使用tf.contrib.learn.Experiment

时,如何限制沟通?

1 个答案:

答案 0 :(得分:2)

# The easiest way to parse TF_CONFIG environment variable is to create a RunConfig.
# Unfortunately, it is an immutable object, so we're going to create a
# temporary one and only use it for `task_type` and `task_id`.
tmp = tf.contrib.learn.RunConfig()
task_type, task_id = tmp.task_type, tmp.task_id

# We use a device_filter to limit the communication between this job
# and the parameter servers, i.e., there is no need to directly
# communicate with the other workers; attempting to do so can result
# in reliability problems.
device_filters = [
    '/job:ps', '/job:%s/task:%d' % (task_type, task_id)
]
session_config = tf.ConfigProto(device_filters=device_filters)
run_config = tf.contrib.learn.RunConfig(
    model_dir=args.job_dir,
    session_config=session_config)

# Create the experiment_fn:
experiment_fn = ...

# Run the experiment
learn_runner.run(experiment_fn, run_config=run_config)