分布式Tensorflow - 将图表传递给其他工作人员

时间:2017-10-18 16:50:52

标签: tensorflow

我想实现这样的目标: 有一个主要脚本,我设置了什么模型配置(超参数,隐藏层数等)我想训练和其他脚本只是用他们的批量数据训练定义的图形/模型,所以我在主工作者上创建一个图表,然后通过它没有在他们上面创建图表给其他工人。神经网络的配置(步数,学习率,隐藏层,隐藏层中的神经元)应该只在主脚本中知道。

类似的东西:

main.py:  # should be run on master worker
  learning_rates = [ 0.1, 0.5 ]
  hidden_layers = [ 1, 2, 3 ]
  for lr in learning_rates:
     for hl in hidden_layers:
        nn_train(lr, hl, x_train, y_train, cluster, job_name, task_index)

而nn_train应该是在TensorFlow的帮助下编写的神经网络模型脚本,并支持数据并行训练。 其他工人应该有类似的东西:

main_slave.py:
    nn_train(None, None, x_another_train, y_another_train, cluster, job_name, task_index)

脚本可能类似于:

nn_train.py:
    # Create and start a server for the local task.
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == "ps":
       server.join()
       return

   if job_name != "worker":
       raise 'Unknown job name: ' + job_name

   if lr and hl: # if they are defined, it's master worker, build the graph
       # Assigns ops to the local worker by default.
       with tf.device(tf.train.replica_device_setter(...)):
           train_op = ... # build whole model, ie graph
    else:
        train_op = ... # **somehow get the graph, how?**

    with tf.train.MonitoredTrainingSession(master=server.target,..) as sess:
       while not mon_sess.should_stop():
       # Run a training step asynchronously.
       # perform *synchronous* training.
       # sess.run handles AbortedError in case of preempted PS.
       sess.run(train_op)

因此,如果在master worker上启动nn_train,那么它会根据给定的参数在那里构建一个图形。如果nn_train在其他工作者上启动,他们只是以某种方式获得已经定义的图并训练它。我假设每台机器都有它自己的数据集,它将在nn_train中传递。

我希望我转达了我的问题。

任何帮助都将受到高度赞赏!

提前谢谢。

0 个答案:

没有答案