在分布式张量流中形成障碍的正确方法是什么?

时间:2016-09-22 12:00:16

标签: concurrency tensorflow

在分布式训练期间,我想在每个时期之后进行同步,对主要工作人员进行一些计算,然后根据这些计算进行或停止训练。我需要一个屏障才能这样做。

我在文档中没有看到任何类似内容,因此我实现了基于队列的解决方案(类似于渐变存储和应用于分布式培训的方式):

def build_barrier(tasks, task_index, barrier_name):
    queues = []
    for i, task in enumerate(tasks):
        with tf.device('%s/cpu:0' % task):
            with tf.name_scope(barrier_name):
                queues.append(
                    tf.FIFOQueue(
                        len(tasks),
                        (tf.float32),
                        shapes=(()),
                        name=str(i),
                        shared_name=str(i)))

    with tf.control_dependencies([queue.enqueue(1.) for queue in queues]):
        return queues[task_index].dequeue_many(len(tasks))

这个想法是为每个工人创建一个队列。对于'signal',我在每个队列中推送一个令牌,并且为了'join',我从相应的队列中取出了许多令牌,我想要同步多少个任务。

问题是:这是正确的做法还是有更好的方式?

2 个答案:

答案 0 :(得分:2)

您的解决方案与SyncReplicasOptimizer非常相似。在SyncReplicasOptimizer中,它使用同步令牌队列来模拟屏障,并使用每个变量的累加器来累积和平均grad更新。它是一种非常典型的批量同步并行,同时还有在Tensorflow中实现过时同步并行的附加功能。

此外,Tensorflow在最新版本中提供Barrier,您可以查看更多信息。

答案 1 :(得分:0)

这是一个模拟张量的纯张量流解决方案。请注意使用两个队列,因为张量流似乎没有合适的解决方案来在分布式会话中自动增加变量,但是queue.size()确实满足了这一要求:

def tf_barrier(shared_name: str, n_workers: int):
    passing_q = tf.FIFOQueue(n_workers, tf.bool, (), shared_name=shared_name + '_count_q')
    blocking_q = tf.FIFOQueue(n_workers, tf.bool, (), shared_name=shared_name + '_barrier_q')
    increment_size = passing_q.enqueue(True) # Atomically increment queue size
    with tf.control_dependencies([increment_size]):
        incremented_size = passing_q.size()
        return tf.cond(tf.equal(incremented_size, n_workers),
                       lambda: tf.group([blocking_q.enqueue_many([[True] * n_workers]), passing_q.dequeue_many(n_workers)]),
                       lambda: blocking_q.dequeue()
                       )

尽管内部复杂,但使用起来非常简单!

with create_session(job.name, task_index) as sess: # Assume 6 workers
    start_barrier = tf_barrier('start', 6)
    sess.run(start_barrier)
    # Every 6th run of start_barrier unblocks the 5 runs before it