我尝试在多台服务器上使用TensorFlow实现同步分布式回归神经网络。这是我的代码的链接: LegacyMixin
我不确定如何让_current_state
变量在所有服务器之间共享,您可以在"培训循环中找到它们#34;我的代码部分。我还提供了以下代码。我希望同一批次中的计算并行发生,但我认为它仍然在每个工作服务器上计算单独的RNN并分别更新参数服务器上的参数。我知道这是因为我在为每个批次运行图表后打印_current_state
变量。此外,同一全局步骤的_total_loss
在每个工作服务器上都不同。
我按照以下链接提供的说明操作:https://github.com/tushar00jain/spark-ml/blob/master/rnn-sync.ipynb https://www.tensorflow.org/deploy/distributed#replicated_training
sess = sv.prepare_or_wait_for_session(server.target)
queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
sv.start_queue_runners(sess, queue_runners)
tf.logging.info('Started %d queues for processing input data.',
len(queue_runners))
if is_chief:
sv.start_queue_runners(sess, chief_queue_runners)
sess.run(init_tokens_op)
print("{0} session ready".format(datetime.now().isoformat()))
#####################################################################
########################### training loop ###########################
_current_state = np.zeros((batch_size, state_size))
for batch_idx in range(args.steps):
if sv.should_stop() or tf_feed.should_stop():
break
batchX, batchY = feed_dict(tf_feed.next_batch(batch_size))
print('==========================================================')
print(_current_state)
if args.mode == "train":
_total_loss, _train_step, _current_state, _predictions_series, _global_step = sess.run(
[total_loss, train_step, current_state, predictions_series, global_step],
feed_dict={
batchX_placeholder:batchX,
batchY_placeholder:batchY,
init_state:_current_state
})
print(_global_step, batch_idx)
print(_current_state)
print('==========================================================')
if _global_step % 5 == 0:
print("Step", _global_step, "Loss", _total_loss)
如何共享_current_state
并对其进行更新,以便所有服务器具有相同global_step
的相同状态?