分布式Tensorflow如何处理这种tf.Variable创建?

时间:2017-12-19 22:54:13

标签: python tensorflow distributed-computing

以下是分布式Tensorflow代码的两个版本,它们(尝试)实现一个全局计数器,该计数器存储在一个参数服务器上并由每个工作程序递增(异步)。

两个版本似乎打印相同的东西,但我不明白这个的原因。版本之间的差异在两行上由注释# NEW表示。

当每个工作人员运行版本1时,参数服务器是否会为每个工作人员自动存储local_counter tf.Variable个?

在版本2中,我尝试将每个local_counter tf.Variable显式放在参数服务器上。

跟随版本1或版本2是否确实有所作为?

PS :我确信这不是管理所有实例之间共享的tf.Variable的最佳方式,因此我很乐意接受有关改进的任何建议。谢谢!

版本1

# Standard distributed Tensorflow boilerplate
# ...

elif FLAGS.job_name == 'worker':

    TASK = FLAGS.task_index

    with tf.device('/job:ps/task:0/cpu:0'):
      with tf.variable_scope('global'):          
        global_counter = tf.Variable(0, name='global_counter',
                                        trainable=False)
        local_counter = tf.Variable(0, name='local_counter_{}'.format(TASK),
                                       trainable=False)
        init_op = tf.global_variables_initializer()

    with tf.device('/job:worker/task:{}'.format(TASK)): 
      with tf.variable_scope('local'):    
        local_inc_op = local_counter.assign_add(1) 
        global_inc_op = global_counter.assign_add(1)

    with tf.Session(server.target):
      sess.run(init_op)
      global_count = 0
      while global_count < 1000:
        sess.run([local_inc_op, global_inc_op])
        local_count, global_count = sess.run([local_counter, global_counter])
        print('Local {}, Global {}, worker-{}'.format(
               local_count, global_count, TASK))

版本2

# Standard distributed Tensorflow boilerplate
# ...

elif FLAGS.job_name == 'worker':

    NUM_WORKERS = len(worker_hosts)
    TASK = FLAGS.task_index

    with tf.device('/job:ps/task:0/cpu:0'):
      with tf.variable_scope('global'):          
        global_counter = tf.Variable(0, name='global_counter',
                                        trainable=False)
        local_counters = [tf.Variable(0, name='local_counter_{}'.format(i),
                                          trainable=False)
                          for i in range(NUM_WORKERS)] # NEW
        init_op = tf.global_variables_initializer()

    with tf.device('/job:worker/task:{}'.format(TASK)): 
      with tf.variable_scope('local'):
        local_counter = local_counters[TASK] # NEW    
        local_inc_op = local_counter.assign_add(1) 
        global_inc_op = global_counter.assign_add(1)

    with tf.Session(server.target):
      sess.run(init_op)
      global_count = 0
      while global_count < 1000:
        sess.run([local_inc_op, global_inc_op])
        local_count, global_count = sess.run([local_counter, global_counter])
        print('Local {}, Global {}, worker-{}'.format(
               local_count, global_count, TASK))

1 个答案:

答案 0 :(得分:1)

我不确定我看到了多少实际差异。在任何一种情况下,本地计数器都在参数服务device范围内创建,因此它们将存在于参数服务器上。在版本1中,每个工作者的图形仅包含其本地计数器,而每个工作者的图形包含版本2中的所有本地计数器(但工作者仍然只与他们自己的计数器交互,并且变量本身仍然住在参数服务器上。)

因此,要明确回答,是的,您可以存储参数服务器上参数服务器图形中不存在的变量。基本上,参数服务器(ResourceMgr)上的哈希表可以存储任意变量名称/值。

要自动将变量放在参数服务器上,tf.train.replica_device_setter可以帮助减少样板。