我对Tensorflow很新,所以我的问题可能听起来很愚蠢,但我真的找不到合适的解释,所以在这里问一下。 我需要您的帮助,以了解如何在图表分布式Tensorflow程序中进行数据批处理或分发。
由于我们执行多个客户端,它们基本上具有相同的代码以获得下一批:
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
我无法理解如何确保非常工作的唯一批次。对我而言,似乎正在向所有工人发送相同的数据。
在这个示例脚本中,每次迭代我们都在读next_batch,因为我们运行了两个带有job_type = worker的客户端,所以这两个worker都会看到相同的next_batch代码。请帮助我理解在这种情况下数据并行性将如何工作。
with sv.prepare_or_wait_for_session(server.target, config=sess_config) as sess:
print("Worker %d: Session initialization complete." % FLAGS.task_index)
# Loop until the supervisor shuts down or 1000000 steps have completed.
step = 0
while not sv.should_stop() and step < 1000000:
# Run a training step asynchronously.
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
print("FETCHING NEXT BATCH %d" % FLAGS.batch_size)
train_feed = {x: batch_xs, y_: batch_ys}
_, step = sess.run([train_op, global_step], feed_dict=train_feed)
if step % 100 == 0:
print("Done step %d" % step)
# Ask for all the services to stop.
sv.stop()
期待您的帮助。
答案 0 :(得分:2)
查看mnist.train.next_batch
和next_batch
(in tensorflow.contrib.learn.python.learn.datasets.mnist)的代码,该代码是由mnist.train.next_batch调用的函数:
- 每个工作人员都有一个单独的DataSet
对象,用于生成数据。因此,每个批次将为每个工人独立生成。