prefetch_queue在cifar10_multi_gpu_train.py中使用

时间:2017-06-30 10:31:27

标签: python tensorflow prefetch

当我将tensorflow代码修改为multi-gpu方式时,我发现每个GPU都从prefetch_queue接收相同的批处理数据。这意味着GPU上的模型接收相同的数据,并给出相同的损失值。关于prefetch_queue.dequeue的代码如下:

with tf.variable_scope(tf.get_variable_scope()):
      for i in xrange(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
            # Dequeues one batch for the GPU
            image_batch, label_batch = batch_queue.dequeue()

0 个答案:

没有答案