Session creation hangs with distributed TensorFlow

Date: 2017-02-02 11:18:22

Tags: python-2.7 tensorflow multiprocessing distributed

I am working through the TensorFlow distributed tutorial:

https://www.tensorflow.org/how_tos/distributed/

I am trying to use Python multiprocessing processes instead of launching the script in separate command shells. Unfortunately, the code hangs at the stage where the session is opened.
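
By "separate command shells" I mean a tutorial-style launcher roughly like the one below; trainer.py and its --job_name/--task_index flags are hypothetical, patterned after the tutorial's example trainer, and only illustrate the setup I am replacing with multiprocessing.

import subprocess
import sys

# Hypothetical trainer.py: a script that builds the ClusterSpec itself and
# starts the right tf.train.Server based on --job_name and --task_index.
procs = [subprocess.Popen([sys.executable, 'trainer.py',
                           '--job_name=ps', '--task_index=0'])]
for task in range(3):
    procs.append(subprocess.Popen([sys.executable, 'trainer.py',
                                   '--job_name=worker',
                                   '--task_index=%d' % task]))
for p in procs:
    p.wait()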

Any ideas are welcome. I have prepared a minimal code example that basically just starts a few parallel processes:

import tensorflow as tf
import time

from multiprocessing import Process

N_WORKERS = 3
SPEC = {'ps': ['127.0.0.1:12222'], 'worker': ['127.0.0.1:12223', '127.0.0.1:12224', '127.0.0.1:12225']}

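# Parameter-server process: build the cluster spec, start the server and block forever.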
def run_ps_server():
    spec = tf.train.ClusterSpec(SPEC)
    ps_server = tf.train.Server(spec, job_name='ps', task_index=0)
    ps_server.join()

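# Worker process: start this task's server, build a tiny graph and run it through a Supervisor-managed session.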
def run_worker(task):
    spec = tf.train.ClusterSpec(SPEC)
    server = tf.train.Server(spec, job_name='worker', task_index=task)
    with tf.device(tf.train.replica_device_setter(1, worker_device="/job:worker/task:%d" % task)):
        global_step = tf.get_variable('global_step', [],
                                      initializer = tf.constant_initializer(0),
                                      trainable = False)
        inc_global_step = tf.assign_add(global_step, 1)
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(is_chief=(task == 0),
                             global_step=global_step,
                             init_op=init_op)
    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(task)])

    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        print 'task {}, global_step {}'.format(task, sess.run(global_step))
        if task == 0:
            sess.run(inc_global_step)
        elif task == 1:
            sess.run(inc_global_step)
            sess.run(inc_global_step)
        print 'task {}, global_step {}'.format(task, sess.run(global_step))

    if task == 2:
        sv.stop()


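# Launch the ps process and N_WORKERS worker processes, then wait for the workers to finish.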
def main(_):
    ps_worker = Process(target=run_ps_server, args=())
    ps_worker.daemon = True
    ps_worker.start()

    worker_processes = []
    for i in xrange(N_WORKERS):
        time.sleep(0.01)
        w = Process(target=run_worker, args=(i,))
        w.daemon = True
        w.start()
        worker_processes.append(w)
    for w in worker_processes: w.join()

    ps_worker.terminate()

if __name__ == '__main__':
    tf.app.run()

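For comparison, here is a single-process sketch that starts every server in one interpreter and opens a plain session directly against one worker target; it is not part of the repro above, just a way to separate fork-related problems from problems with the graph or cluster spec.

import tensorflow as tf

SPEC = {'ps': ['127.0.0.1:12222'],
        'worker': ['127.0.0.1:12223', '127.0.0.1:12224', '127.0.0.1:12225']}

# Start all servers inside this one process, then run a trivial op through a
# plain tf.Session against a worker target. If this works while the forked
# version hangs, the hang is tied to process creation, not to the graph.
spec = tf.train.ClusterSpec(SPEC)
servers = [tf.train.Server(spec, job_name='ps', task_index=0)]
for i in range(3):
    servers.append(tf.train.Server(spec, job_name='worker', task_index=i))

with tf.Session(servers[1].target) as sess:
    print sess.run(tf.constant(42.0))
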
Python 2.7, TensorFlow 0.12.1 (CPU), Mint 17 (Ubuntu x64).

EDIT:

The problem does not reproduce with the CUDA build of TensorFlow.

0 Answers

There are no answers yet.