I am trying to initialize global variables on different hosts in a distributed TensorFlow cluster. The question is how to do the initialization through local sessions rather than through the chief session.
A cluster with two hosts is set up. On each host, a MonitoredSession is created with a Scaffold to supervise initialization. When the MonitoredSession is created with a Scaffold that carries no explicit init_op, all global variables get initialized, and a race condition on the values of weight is observed. When an init_op covering only the weight on the local host is provided, a RuntimeError about uninitialized variables is raised.
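To make the two setups concrete, here is a minimal sketch of the contrast (illustration only; init_op refers to the list of local initializers built in the full script below):

# Default Scaffold: its default init_op initializes ALL global variables
# (via tf.global_variables_initializer, as far as I can tell). Both tasks
# act as chief here, so both run it -- which would explain the race on
# the weight values.
scaffold = tf.train.Scaffold()

# Scaffold with an explicit init_op covering only the local weight; the
# weight on the other host then stays uninitialized from this worker's
# point of view.
scaffold = tf.train.Scaffold(init_op=init_op)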
Python version: 3.6.8; TensorFlow version: 1.13.1. The script used to launch the code:
import sys
import argparse

import tensorflow as tf

FLAGS = None


def test_distribute(_):
    """Test distributed communication."""
    worker_hosts = FLAGS.worker_hosts.split(',')
    cluster_spec = tf.train.ClusterSpec({'worker': worker_hosts})
    tf.set_random_seed(FLAGS.task_index)
    # Define the model: one scalar weight per worker, each placed on that
    # worker's device; only the local weight's initializer goes into init_op.
    init_op = list()
    weight_list = list()
    cluster = cluster_spec.as_dict()
    device_fmt = '/job:worker/task:{}/device:CPU:0'
    scope_fmt = 'worker{}'
    for task, _ in enumerate(cluster['worker']):
        with tf.variable_scope(scope_fmt.format(task)), \
                tf.device(device_fmt.format(task)):
            weight = tf.get_variable('weight', shape=[])
            weight_list.append(weight)
            if task == FLAGS.task_index:
                init_op.append(weight.initializer)
    server = tf.train.Server(cluster_spec,
                             job_name='worker', task_index=FLAGS.task_index)
    # scaffold = tf.train.Scaffold(init_op=init_op)
    scaffold = tf.train.Scaffold()
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, master=server.target)
    with tf.train.MonitoredSession(session_creator=session_creator) as mon_sess:
        result = mon_sess.run(weight_list)
        print('worker {} get weights {}'.format(FLAGS.task_index, result))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=test_distribute, argv=[sys.argv[0]] + unparsed)
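As a debugging aid (my own addition, not part of the script above), which variables a worker still sees as uninitialized can be inspected with a plain session against the same server target; tf.report_uninitialized_variables is, as far as I can tell, also what the default Scaffold ready_op is built from:

# Hypothetical probe: list the variables this worker reports as uninitialized.
with tf.Session(server.target) as sess:
    print(sess.run(tf.report_uninitialized_variables()))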
When the two workers are launched with (python tf_distribute.py --worker_hosts=localhost:40000,localhost:40001 --task_index=0) & (python tf_distribute.py --worker_hosts=localhost:40000,localhost:40001 --task_index=1), their outputs are usually identical:

worker 0 get weights [0.9657415, 0.94129646]
worker 1 get weights [0.9657415, 0.94129646]

or:

worker 1 get weights [-0.62132096, 0.1889062]
worker 0 get weights [-0.62132096, 0.1889062]
but occasionally they differ:

worker 0 get weights [-0.62132096, 0.1889062]
worker 1 get weights [0.9657415, 0.94129646]

When the commented-out line that passes the local init_op to the Scaffold is enabled instead, the differing output above can occasionally be obtained; usually, however, a RuntimeError is raised.
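To state the intended behavior precisely, what I am after is roughly this sketch (hypothetical, not working code): each worker runs only its own initializer through a plain local session, so no worker ever touches the weight on the other host:

# Hypothetical: initialize only this task's weight via a local session,
# then have the monitored session accept the model as ready.
with tf.Session(server.target) as sess:
    sess.run(init_op)  # holds only the local weight.initializer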