共享RNN状态分布式TensorFlow

时间:2017-04-03 06:11:40

标签: tensorflow recurrent-neural-network tensorflow-serving

我尝试在多台服务器上使用TensorFlow实现同步分布式回归神经网络。这是我的代码的链接: LegacyMixin

我不确定如何让_current_state变量在所有服务器之间共享,您可以在"培训循环中找到它们#34;我的代码部分。我还提供了以下代码。我希望同一批次中的计算并行发生,但我认为它仍然在每个工作服务器上计算单独的RNN并分别更新参数服务器上的参数。我知道这是因为我在为每个批次运行图表后打印_current_state变量。此外,同一全局步骤的_total_loss在每个工作服务器上都不同。

我按照以下链接提供的说明操作:https://github.com/tushar00jain/spark-ml/blob/master/rnn-sync.ipynb https://www.tensorflow.org/deploy/distributed#replicated_training

        sess = sv.prepare_or_wait_for_session(server.target)
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)

        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        if is_chief:
                sv.start_queue_runners(sess, chief_queue_runners)
                sess.run(init_tokens_op)

        print("{0} session ready".format(datetime.now().isoformat()))
        #####################################################################

        ########################### training loop ###########################
        _current_state = np.zeros((batch_size, state_size))
        for batch_idx in range(args.steps):
            if sv.should_stop() or tf_feed.should_stop():
                break

            batchX, batchY = feed_dict(tf_feed.next_batch(batch_size))

            print('==========================================================')
            print(_current_state)

            if args.mode == "train":
                _total_loss, _train_step, _current_state, _predictions_series, _global_step = sess.run(
                [total_loss, train_step, current_state, predictions_series, global_step],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

                print(_global_step, batch_idx)
                print(_current_state)
                print('==========================================================')

                if _global_step % 5 == 0:
                    print("Step", _global_step, "Loss", _total_loss)  

如何共享_current_state并对其进行更新,以便所有服务器具有相同global_step的相同状态?

0 个答案:

没有答案