阅读功能average_gradients
的{{3}},提供了以下评论:Note that this function provides a synchronization point across all towers.
功能average_gradients
是阻止呼叫,synchronization point
是什么意思?
我认为这是一个阻塞调用,因为为了计算每个梯度必须单独计算的梯度的平均值?但是,对于所有单独的梯度计算,阻塞代码在哪里等待?
答案 0 :(得分:6)
average_gradients
本身不是阻止函数。它可能是具有张量流操作的另一个函数,这仍然是一个同步点。阻止它的原因是它使用了参数 tower_grads
,这取决于前一个for循环中创建的所有图形。
基本上这里发生的是创建训练图。首先,在for循环for i in xrange(FLAGS.num_gpus)
中创建了几个图形“线程”。每个看起来像这样:
计算损失 - >计算梯度 - >附加到tower_grads
每个图形“线程”都通过with tf.device('/gpu:%d' % i)
分配给不同的gpu,并且每个图形可以彼此独立运行(并且稍后将并行运行)。现在,下次使用tower_grads
而没有设备规范时,它会在主设备上创建一个图表延续,将所有这些单独的图形“线程”绑定到一个图形。在运行tower_grads
函数内的图之前,Tensorflow将确保完成average_gradients
创建的每个图形“线程”。因此,稍后调用sess.run([train_op, loss])
时,这将是图表的同步点。