我使用tensorflow版本0.12.1,然后使用this doc。
我想要做的是在每个工作人员中为count
添加1。
我的目标是打印> 1
的结果,但我只获得1
。
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('job_name', '', '')
tf.app.flags.DEFINE_string('ps_hosts', '','')
tf.app.flags.DEFINE_string('worker_hosts', '','')
tf.app.flags.DEFINE_integer('task_index', 0, '')
ps_hosts = FLAGS.ps_hosts.split(',')
worker_hosts = FLAGS.worker_hosts.split(',')
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,'worker': worker_hosts})
server = tf.train.Server(
{'ps': ps_hosts,'worker': worker_hosts},
job_name=FLAGS.job_name,
task_index=FLAGS.task_index)
if FLAGS.job_name == 'ps':
server.join()
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster_spec)):
count = tf.Variable(0)
count = tf.add(count,tf.constant(1))
init = tf.global_variables_initializer()
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
logdir="./checkpoint/",
init_op=init,
summary_op=None,
saver=None,
global_step=None,
save_model_secs=60)
with sv.managed_session(server.target) as sess:
sess.run(init)
step = 1
while step <= 999999999:
result = sess.run(count)
if step%10000 == 0:
print(result)
if result>=2:
print("!!!!!!!!")
step += 1
print("Finished!")
sv.stop()
答案 0 :(得分:0)
问题实际上与分布式执行无关,并且源于这两行:
count = tf.Variable(0)
count = tf.add(count,tf.constant(1))
tf.add()
op是一个纯函数op,它在每次运行时创建一个带有输出的新张量,而不是修改它的输入。如果您希望值增加,并且该增加在工作人员中可见,则必须使用tf.Variable.assign_add()
方法,如下所示:
count = tf.Variable(0)
increment_count = count.assign_add(1)
然后在训练循环中调用sess.run(increment_count)
以增加count
变量的值。