在Distributed Tensorflow中批量处理数据

时间:2017-06-19 18:25:51

标签: tensorflow distributed

我对Tensorflow很新,所以我的问题可能听起来很愚蠢,但我真的找不到合适的解释,所以在这里问一下。 我需要您的帮助,以了解如何在图表分布式Tensorflow程序中进行数据批处理或分发。

由于我们执行多个客户端,它们基本上具有相同的代码以获得下一批:

batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)

我无法理解如何确保非常工作的唯一批次。对我而言,似乎正在向所有工人发送相同的数据。

在这个示例脚本中,每次迭代我们都在读next_batch,因为我们运行了两个带有job_type = worker的客户端,所以这两个worker都会看到相同的next_batch代码。请帮助我理解在这种情况下数据并行性将如何工作。

 with sv.prepare_or_wait_for_session(server.target, config=sess_config) as sess:
        print("Worker %d: Session initialization complete." % FLAGS.task_index)
        # Loop until the supervisor shuts down or 1000000 steps have completed.
        step = 0
        while not sv.should_stop() and step < 1000000:
            # Run a training step asynchronously.
            batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
            print("FETCHING NEXT BATCH %d" % FLAGS.batch_size)
            train_feed = {x: batch_xs, y_: batch_ys}

            _, step = sess.run([train_op, global_step], feed_dict=train_feed)
            if step % 100 == 0:
                print("Done step %d" % step)

    # Ask for all the services to stop.
    sv.stop()

期待您的帮助。

1 个答案:

答案 0 :(得分:2)

查看mnist.train.next_batchnext_batchin tensorflow.contrib.learn.python.learn.datasets.mnist)的代码,该代码是由mnist.train.next_batch调用的函数: - 每个工作人员都有一个单独的DataSet对象,用于生成数据。因此,每个批次将为每个工人独立生成。

  • 每个数据点可以在每个时期跨工人使用多次,但子集是随机获取的,因此可能不是问题。即使一些工作人员可能看到相同的数据点,批次本身也是随机生成的