如何完成这个非常简单的分布式培训示例?

时间:2017-02-28 09:28:10

标签: tensorflow

我使用tensorflow版本0.12.1,然后使用this doc

我想要做的是在每个工作人员中为count添加1。

我的目标是打印> 1的结果,但我只获得1

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('job_name', '', '')
tf.app.flags.DEFINE_string('ps_hosts', '','')
tf.app.flags.DEFINE_string('worker_hosts', '','')
tf.app.flags.DEFINE_integer('task_index', 0, '')

ps_hosts = FLAGS.ps_hosts.split(',')
worker_hosts = FLAGS.worker_hosts.split(',')
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,'worker': worker_hosts})
server = tf.train.Server(
                    {'ps': ps_hosts,'worker': worker_hosts},
                    job_name=FLAGS.job_name,
                    task_index=FLAGS.task_index)

if FLAGS.job_name == 'ps':
  server.join()

with tf.device(tf.train.replica_device_setter(
               worker_device="/job:worker/task:%d" % FLAGS.task_index,
               cluster=cluster_spec)):
  count = tf.Variable(0)
  count = tf.add(count,tf.constant(1))
  init = tf.global_variables_initializer()

sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                            logdir="./checkpoint/",
                            init_op=init,
                            summary_op=None,
                            saver=None,
                            global_step=None,
                            save_model_secs=60)

with sv.managed_session(server.target) as sess:
    sess.run(init)
    step = 1
    while step <= 999999999:
        result = sess.run(count)
        if step%10000 == 0:
          print(result)
        if result>=2:
          print("!!!!!!!!")
        step += 1
    print("Finished!")

sv.stop() 

1 个答案:

答案 0 :(得分:0)

问题实际上与分布式执行无关,并且源于这两行:

  count = tf.Variable(0)
  count = tf.add(count,tf.constant(1))

tf.add() op是一个纯函数op,它在每次运行时创建一个带有输出的新张量,而不是修改它的输入。如果您希望值增加,并且该增加在工作人员中可见,则必须使用tf.Variable.assign_add()方法,如下所示:

  count = tf.Variable(0)
  increment_count = count.assign_add(1)

然后在训练循环中调用sess.run(increment_count)以增加count变量的值。