Tensorflow cifar同步点

时间:2017-05-03 17:11:10

标签: tensorflow deep-learning

阅读功能average_gradients的{​​{3}},提供了以下评论:Note that this function provides a synchronization point across all towers.功能average_gradients是阻止呼叫,synchronization point是什么意思?

我认为这是一个阻塞调用,因为为了计算每个梯度必须单独计算的梯度的平均值?但是,对于所有单独的梯度计算,阻塞代码在哪里等待?

1 个答案:

答案 0 :(得分:6)

average_gradients本身不是阻止函数。它可能是具有张量流操作的另一个函数,这仍然是一个同步点。阻止它的原因是它使用了参数 tower_grads ,这取决于前一个for循环中创建的所有图形。

基本上这里发生的是创建训练图。首先,在for循环for i in xrange(FLAGS.num_gpus)中创建了几个图形“线程”。每个看起来像这样:

计算损失 - >计算梯度 - >附加到tower_grads

每个图形“线程”都通过with tf.device('/gpu:%d' % i)分配给不同的gpu,并且每个图形可以彼此独立运行(并且稍后将并行运行)。现在,下次使用tower_grads而没有设备规范时,它会在主设备上创建一个图表延续,将所有这些单独的图形“线程”绑定到一个图形。在运行tower_grads函数内的图之前,Tensorflow将确保完成average_gradients创建的每个图形“线程”。因此,稍后调用sess.run([train_op, loss])时,这将是图表的同步点。