By default, distributed training in TensorFlow establishes all-to-all connections between the workers and parameter servers, even though in asynchronous distributed training the only necessary communication is between each individual worker and the parameter servers.
How can I limit the communication to only what is necessary when I'm using tf.contrib.learn.Experiment?
Answer 0: (score: 2)
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner

# The easiest way to parse the TF_CONFIG environment variable is to create
# a RunConfig. Unfortunately, it is an immutable object, so we create a
# temporary one and use it only for `task_type` and `task_id`.
tmp = tf.contrib.learn.RunConfig()
task_type, task_id = tmp.task_type, tmp.task_id
# We use a device_filter to limit the communication between this job
# and the parameter servers, i.e., there is no need to directly
# communicate with the other workers; attempting to do so can result
# in reliability problems.
device_filters = [
    '/job:ps', '/job:%s/task:%d' % (task_type, task_id)
]
session_config = tf.ConfigProto(device_filters=device_filters)
run_config = tf.contrib.learn.RunConfig(
    model_dir=args.job_dir,
    session_config=session_config)
# Create the experiment_fn:
experiment_fn = ...
# Run the experiment
learn_runner.run(experiment_fn, run_config=run_config)
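For reference, the filter strings follow TensorFlow's `/job:<name>/task:<index>` device-spec syntax: each worker keeps the whole `ps` job plus its own task, and every other worker is filtered out. A minimal sketch of how each worker's filter list would look (pure Python, no TensorFlow required; `build_device_filters` is a hypothetical helper for illustration):

```python
def build_device_filters(task_type, task_id):
    # Keep connections to all parameter servers and to this task itself;
    # connections to the other workers are dropped by the filter.
    return ['/job:ps', '/job:%s/task:%d' % (task_type, task_id)]

# Worker 0 and worker 3 each see only the ps job and themselves:
print(build_device_filters('worker', 0))  # ['/job:ps', '/job:worker/task:0']
print(build_device_filters('worker', 3))  # ['/job:ps', '/job:worker/task:3']
```

Because `'/job:ps'` names the whole job without a task index, it matches every parameter server task, so the filter list stays the same size no matter how many parameter servers the cluster has.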